langroid 0.58.2__py3-none-any.whl → 0.59.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. langroid/agent/base.py +39 -17
  2. langroid/agent/base.py-e +2216 -0
  3. langroid/agent/callbacks/chainlit.py +2 -1
  4. langroid/agent/chat_agent.py +73 -55
  5. langroid/agent/chat_agent.py-e +2086 -0
  6. langroid/agent/chat_document.py +7 -7
  7. langroid/agent/chat_document.py-e +513 -0
  8. langroid/agent/openai_assistant.py +9 -9
  9. langroid/agent/openai_assistant.py-e +882 -0
  10. langroid/agent/special/arangodb/arangodb_agent.py +10 -18
  11. langroid/agent/special/arangodb/arangodb_agent.py-e +648 -0
  12. langroid/agent/special/arangodb/tools.py +3 -3
  13. langroid/agent/special/doc_chat_agent.py +16 -14
  14. langroid/agent/special/lance_rag/critic_agent.py +2 -2
  15. langroid/agent/special/lance_rag/query_planner_agent.py +4 -4
  16. langroid/agent/special/lance_tools.py +6 -5
  17. langroid/agent/special/lance_tools.py-e +61 -0
  18. langroid/agent/special/neo4j/neo4j_chat_agent.py +3 -7
  19. langroid/agent/special/neo4j/neo4j_chat_agent.py-e +430 -0
  20. langroid/agent/special/relevance_extractor_agent.py +1 -1
  21. langroid/agent/special/sql/sql_chat_agent.py +11 -3
  22. langroid/agent/task.py +9 -87
  23. langroid/agent/task.py-e +2418 -0
  24. langroid/agent/tool_message.py +33 -17
  25. langroid/agent/tool_message.py-e +400 -0
  26. langroid/agent/tools/file_tools.py +4 -2
  27. langroid/agent/tools/file_tools.py-e +234 -0
  28. langroid/agent/tools/mcp/fastmcp_client.py +19 -6
  29. langroid/agent/tools/mcp/fastmcp_client.py-e +584 -0
  30. langroid/agent/tools/orchestration.py +22 -17
  31. langroid/agent/tools/orchestration.py-e +301 -0
  32. langroid/agent/tools/recipient_tool.py +3 -3
  33. langroid/agent/tools/task_tool.py +22 -16
  34. langroid/agent/tools/task_tool.py-e +249 -0
  35. langroid/agent/xml_tool_message.py +90 -35
  36. langroid/agent/xml_tool_message.py-e +392 -0
  37. langroid/cachedb/base.py +1 -1
  38. langroid/embedding_models/base.py +2 -2
  39. langroid/embedding_models/models.py +3 -7
  40. langroid/embedding_models/models.py-e +563 -0
  41. langroid/exceptions.py +4 -1
  42. langroid/language_models/azure_openai.py +2 -2
  43. langroid/language_models/azure_openai.py-e +134 -0
  44. langroid/language_models/base.py +6 -4
  45. langroid/language_models/base.py-e +812 -0
  46. langroid/language_models/client_cache.py +64 -0
  47. langroid/language_models/config.py +2 -4
  48. langroid/language_models/config.py-e +18 -0
  49. langroid/language_models/model_info.py +9 -1
  50. langroid/language_models/model_info.py-e +483 -0
  51. langroid/language_models/openai_gpt.py +119 -20
  52. langroid/language_models/openai_gpt.py-e +2280 -0
  53. langroid/language_models/provider_params.py +3 -22
  54. langroid/language_models/provider_params.py-e +153 -0
  55. langroid/mytypes.py +11 -4
  56. langroid/mytypes.py-e +132 -0
  57. langroid/parsing/code_parser.py +1 -1
  58. langroid/parsing/file_attachment.py +1 -1
  59. langroid/parsing/file_attachment.py-e +246 -0
  60. langroid/parsing/md_parser.py +14 -4
  61. langroid/parsing/md_parser.py-e +574 -0
  62. langroid/parsing/parser.py +22 -7
  63. langroid/parsing/parser.py-e +410 -0
  64. langroid/parsing/repo_loader.py +3 -1
  65. langroid/parsing/repo_loader.py-e +812 -0
  66. langroid/parsing/search.py +1 -1
  67. langroid/parsing/url_loader.py +17 -51
  68. langroid/parsing/url_loader.py-e +683 -0
  69. langroid/parsing/urls.py +5 -4
  70. langroid/parsing/urls.py-e +279 -0
  71. langroid/prompts/prompts_config.py +1 -1
  72. langroid/pydantic_v1/__init__.py +45 -6
  73. langroid/pydantic_v1/__init__.py-e +36 -0
  74. langroid/pydantic_v1/main.py +11 -4
  75. langroid/pydantic_v1/main.py-e +11 -0
  76. langroid/utils/configuration.py +13 -11
  77. langroid/utils/configuration.py-e +141 -0
  78. langroid/utils/constants.py +1 -1
  79. langroid/utils/constants.py-e +32 -0
  80. langroid/utils/globals.py +21 -5
  81. langroid/utils/globals.py-e +49 -0
  82. langroid/utils/html_logger.py +2 -1
  83. langroid/utils/html_logger.py-e +825 -0
  84. langroid/utils/object_registry.py +1 -1
  85. langroid/utils/object_registry.py-e +66 -0
  86. langroid/utils/pydantic_utils.py +55 -28
  87. langroid/utils/pydantic_utils.py-e +602 -0
  88. langroid/utils/types.py +2 -2
  89. langroid/utils/types.py-e +113 -0
  90. langroid/vector_store/base.py +3 -3
  91. langroid/vector_store/lancedb.py +5 -5
  92. langroid/vector_store/lancedb.py-e +404 -0
  93. langroid/vector_store/meilisearch.py +2 -2
  94. langroid/vector_store/pineconedb.py +4 -4
  95. langroid/vector_store/pineconedb.py-e +427 -0
  96. langroid/vector_store/postgres.py +1 -1
  97. langroid/vector_store/qdrantdb.py +3 -3
  98. langroid/vector_store/weaviatedb.py +1 -1
  99. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/METADATA +3 -2
  100. langroid-0.59.0b1.dist-info/RECORD +181 -0
  101. langroid/agent/special/doc_chat_task.py +0 -0
  102. langroid/mcp/__init__.py +0 -1
  103. langroid/mcp/server/__init__.py +0 -1
  104. langroid-0.58.2.dist-info/RECORD +0 -145
  105. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/WHEEL +0 -0
  106. {langroid-0.58.2.dist-info → langroid-0.59.0b1.dist-info}/licenses/LICENSE +0 -0
langroid/agent/chat_agent.py-e (new file)
@@ -0,0 +1,2086 @@
+import copy
+import inspect
+import json
+import logging
+import textwrap
+from contextlib import ExitStack
+from inspect import isclass
+from typing import Any, Dict, List, Optional, Self, Set, Tuple, Type, Union, cast
+
+import openai
+from rich import print
+from rich.console import Console
+from rich.markup import escape
+
+from langroid.agent.base import Agent, AgentConfig, async_noop_fn, noop_fn
+from langroid.agent.chat_document import ChatDocument
+from langroid.agent.tool_message import (
+    ToolMessage,
+    format_schema_for_strict,
+)
+from langroid.agent.xml_tool_message import XMLToolMessage
+from langroid.language_models.base import (
+    LLMFunctionCall,
+    LLMFunctionSpec,
+    LLMMessage,
+    LLMResponse,
+    OpenAIJsonSchemaSpec,
+    OpenAIToolSpec,
+    Role,
+    StreamingIfAllowed,
+    ToolChoiceTypes,
+)
+from langroid.language_models.openai_gpt import OpenAIGPT
+from langroid.mytypes import Entity, NonToolAction
+from pydantic import BaseModel, ValidationError
+from langroid.utils.configuration import settings
+from langroid.utils.object_registry import ObjectRegistry
+from langroid.utils.output import status
+from langroid.utils.pydantic_utils import PydanticWrapper, get_pydantic_wrapper
+from langroid.utils.types import is_callable
+
+console = Console()
+
+logger = logging.getLogger(__name__)
+
+
+class ChatAgentConfig(AgentConfig):
+    """
+    Configuration for ChatAgent
+
+    Attributes:
+        system_message: system message to include in message sequence
+            (typically defines role and task of agent).
+            Used only if `task` is not specified in the constructor.
+        user_message: user message to include in message sequence.
+            Used only if `task` is not specified in the constructor.
+        use_tools: whether to use our own ToolMessages mechanism
+        handle_llm_no_tool (Any): desired agent_response when the
+            LLM generates a non-tool msg.
+        use_functions_api: whether to use functions/tools native to the LLM API
+            (e.g. OpenAI's `function_call` or `tool_call` mechanism)
+        use_tools_api: When `use_functions_api` is True, if this is also True,
+            the OpenAI tool-call API is used, rather than the older/deprecated
+            function-call API; setting this to False falls back to the
+            function-call API.
+        strict_recovery: whether to enable strict schema recovery when there
+            is a tool-generation error.
+        enable_orchestration_tool_handling: whether to enable handling of
+            orchestration tools, e.g. ForwardTool, DoneTool, PassTool, etc.
+        output_format: When supported by the LLM (certain OpenAI LLMs
+            and local LLMs served by providers such as vLLM), ensures
+            that the output is JSON matching the corresponding
+            schema, via grammar-based decoding.
+        handle_output_format: When `output_format` is a `ToolMessage` T,
+            controls whether T is "enabled for handling".
+        use_output_format: When `output_format` is a `ToolMessage` T,
+            controls whether T is "enabled for use" (by LLM) and
+            instructions on using T are added to the system message.
+        instructions_output_format: Controls whether we generate instructions for
+            `output_format` in the system message.
+        use_tools_on_output_format: Controls whether to automatically switch
+            to the Langroid-native tools mechanism when `output_format` is set.
+            Note that LLMs may generate tool calls which do not belong to
+            `output_format` even when strict JSON mode is enabled, so this should
+            be enabled when such tool calls are not desired.
+        output_format_include_defaults: Whether to include fields with default
+            arguments in the output schema.
+        full_citations: Whether to show source reference citation + content for
+            each citation, or just the main reference citation.
+    """
+
+    system_message: str = "You are a helpful assistant."
+    user_message: Optional[str] = None
+    handle_llm_no_tool: Any = None
+    use_tools: bool = True
+    use_functions_api: bool = False
+    use_tools_api: bool = True
+    strict_recovery: bool = True
+    enable_orchestration_tool_handling: bool = True
+    output_format: Optional[type] = None
+    handle_output_format: bool = True
+    use_output_format: bool = True
+    instructions_output_format: bool = True
+    output_format_include_defaults: bool = True
+    use_tools_on_output_format: bool = True
+    full_citations: bool = True  # show source + content for each citation?
+
+    def _set_fn_or_tools(self) -> None:
+        """
+        Enable Langroid Tool or OpenAI-like fn-calling,
+        depending on config settings.
+        """
+        if not self.use_functions_api or not self.use_tools:
+            return
+        if self.use_functions_api and self.use_tools:
+            logger.debug(
+                """
+                You have enabled both `use_tools` and `use_functions_api`.
+                Setting `use_functions_api` to False.
+                """
+            )
+            self.use_tools = True
+            self.use_functions_api = False
+
+
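A minimal usage sketch of the config above (not part of the package diff); the agent name, model defaults and message strings are illustrative assumptions:

    import langroid as lr

    config = lr.ChatAgentConfig(
        name="assistant",  # hypothetical agent name
        system_message="You are a concise technical assistant.",
        use_tools=True,           # Langroid-native, prompt-based tools
        use_functions_api=False,  # per _set_fn_or_tools, both cannot stay enabled
    )
    agent = lr.ChatAgent(config)
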
+class ChatAgent(Agent):
+    """
+    Chat Agent interacting with external env
+    (could be human, or external tools).
+    The agent (the LLM actually) is provided with an optional "Task Spec",
+    which is a sequence of `LLMMessage`s. These are used to initialize
+    the `task_messages` of the agent.
+    In most applications we will use a `ChatAgent` rather than a bare `Agent`.
+    The `Agent` class mainly exists to hold various common methods and attributes.
+    One difference between `ChatAgent` and `Agent` is that `ChatAgent`'s
+    `llm_response` method uses "chat mode" API (i.e. one that takes a
+    message sequence rather than a single message),
+    whereas the same method in the `Agent` class uses "completion mode" API
+    (i.e. one that takes a single message).
+    """
+
+    def __init__(
+        self,
+        config: ChatAgentConfig = ChatAgentConfig(),
+        task: Optional[List[LLMMessage]] = None,
+    ):
+        """
+        Chat-mode agent initialized with task spec as the initial message sequence
+        Args:
+            config: settings for the agent
+
+        """
+        super().__init__(config)
+        self.config: ChatAgentConfig = config
+        self.config._set_fn_or_tools()
+        self.message_history: List[LLMMessage] = []
+        self.init_state()
+        # An agent's "task" is defined by a system msg and an optional user msg;
+        # These are "priming" messages that kick off the agent's conversation.
+        self.system_message: str = self.config.system_message
+        self.user_message: str | None = self.config.user_message
+
+        if task is not None:
+            # if task contains a system msg, we override the config system msg
+            if len(task) > 0 and task[0].role == Role.SYSTEM:
+                self.system_message = task[0].content
+            # if task contains a user msg, we override the config user msg
+            if len(task) > 1 and task[1].role == Role.USER:
+                self.user_message = task[1].content
+
+        # system-level instructions for using tools/functions:
+        # We maintain these as tools/functions are enabled/disabled,
+        # and whenever an LLM response is sought, these are used to
+        # recreate the system message (via `_create_system_and_tools_message`)
+        # each time, so it reflects the current set of enabled tools/functions.
+        # (a) these are general instructions on using certain tools/functions,
+        #     if they are specified in a ToolMessage class as a classmethod
+        #     `instructions`
+        self.system_tool_instructions: str = ""
+        # (b) these are only for the built-in Langroid TOOLS mechanism:
+        self.system_tool_format_instructions: str = ""
+
+        self.llm_functions_map: Dict[str, LLMFunctionSpec] = {}
+        self.llm_functions_handled: Set[str] = set()
+        self.llm_functions_usable: Set[str] = set()
+        self.llm_function_force: Optional[Dict[str, str]] = None
+
+        self.output_format: Optional[type[ToolMessage | BaseModel]] = None
+
+        self.saved_requests_and_tool_settings = self._requests_and_tool_settings()
+        # This variable is not None and equals a `ToolMessage` T, if and only if:
+        # (a) T has been set as the output_format of this agent, AND
+        # (b) T has been "enabled for use" ONLY for enforcing this output format,
+        #     AND
+        # (c) T has NOT been explicitly "enabled for use" by this Agent.
+        self.enabled_use_output_format: Optional[type[ToolMessage]] = None
+        # As above but deals with "enabled for handling" instead of
+        # "enabled for use".
+        self.enabled_handling_output_format: Optional[type[ToolMessage]] = None
+        if config.output_format is not None:
+            self.set_output_format(config.output_format)
+        # instructions specifically related to enforcing `output_format`
+        self.output_format_instructions = ""
+
+        # controls whether to disable strict schemas for this agent if
+        # strict mode causes exception
+        self.disable_strict = False
+        # Tracks whether any strict tool is enabled; used to determine whether to
+        # set `self.disable_strict` on an exception
+        self.any_strict = False
+        # Tracks the set of tools on which we force-disable strict decoding
+        self.disable_strict_tools_set: set[str] = set()
+
+        if self.config.enable_orchestration_tool_handling:
+            # Only enable HANDLING by `agent_response`, NOT LLM generation of
+            # these. This is useful where tool-handlers or agent_response
+            # generate these tools, and they need to be handled.
+            # We don't want to enable orch tool GENERATION by default, since that
+            # might clutter up the LLM system message unnecessarily.
+            from langroid.agent.tools.orchestration import (
+                AgentDoneTool,
+                AgentSendTool,
+                DonePassTool,
+                DoneTool,
+                ForwardTool,
+                PassTool,
+                ResultTool,
+                SendTool,
+            )
+
+            self.enable_message(ForwardTool, use=False, handle=True)
+            self.enable_message(DoneTool, use=False, handle=True)
+            self.enable_message(AgentDoneTool, use=False, handle=True)
+            self.enable_message(PassTool, use=False, handle=True)
+            self.enable_message(DonePassTool, use=False, handle=True)
+            self.enable_message(SendTool, use=False, handle=True)
+            self.enable_message(AgentSendTool, use=False, handle=True)
+            self.enable_message(ResultTool, use=False, handle=True)
+
+    def init_state(self) -> None:
+        """
+        Initialize the state of the agent. Just conversation state here,
+        but subclasses can override this to initialize other state.
+        """
+        super().init_state()
+        self.clear_history(0)
+        self.clear_dialog()
+
+    @staticmethod
+    def from_id(id: str) -> "ChatAgent":
+        """
+        Get an agent from its ID
+        Args:
+            id (str): ID of the agent
+        Returns:
+            ChatAgent: The agent with the given ID
+        """
+        return cast(ChatAgent, Agent.from_id(id))
+
+    def clone(self, i: int = 0) -> "ChatAgent":
+        """Create the i'th clone of this agent, ensuring tool use/handling is
+        cloned.
+        Important: We assume all member variables are in the __init__ method here
+        and in the Agent class.
+        TODO: We are attempting to clone an agent after its state has been
+        changed in possibly many ways. Below is an imperfect solution. Caution
+        advised. Revisit later.
+        """
+        agent_cls = type(self)
+        config_copy = copy.deepcopy(self.config)
+        config_copy.name = f"{config_copy.name}-{i}"
+        new_agent = agent_cls(config_copy)
+        new_agent.system_tool_instructions = self.system_tool_instructions
+        new_agent.system_tool_format_instructions = self.system_tool_format_instructions
+        new_agent.llm_tools_map = self.llm_tools_map
+        new_agent.llm_functions_map = self.llm_functions_map
+        new_agent.llm_functions_handled = self.llm_functions_handled
+        new_agent.llm_functions_usable = self.llm_functions_usable
+        new_agent.llm_function_force = self.llm_function_force
+        # Caution - we are copying the vector-db; maybe we don't always want this?
+        new_agent.vecdb = self.vecdb
+        new_agent.id = ObjectRegistry.new_id()
+        if self.config.add_to_registry:
+            ObjectRegistry.register_object(new_agent)
+        return new_agent
+
+    def _strict_mode_for_tool(self, tool: str | type[ToolMessage]) -> bool:
+        """Should we enable strict mode for a given tool?"""
+        if isinstance(tool, str):
+            tool_class = self.llm_tools_map[tool]
+        else:
+            tool_class = tool
+        name = tool_class.default_value("request")
+        if name in self.disable_strict_tools_set or self.disable_strict:
+            return False
+        strict: Optional[bool] = tool_class.default_value("strict")
+        if strict is None:
+            strict = self._strict_tools_available()
+
+        return strict
+
+    def _fn_call_available(self) -> bool:
+        """Does this agent's LLM support function calling?"""
+        return self.llm is not None and self.llm.supports_functions_or_tools()
+
+    def _strict_tools_available(self) -> bool:
+        """Does this agent's LLM support strict tools?"""
+        return (
+            not self.disable_strict
+            and self.llm is not None
+            and isinstance(self.llm, OpenAIGPT)
+            and self.llm.config.parallel_tool_calls is False
+            and self.llm.supports_strict_tools
+        )
+
+    def _json_schema_available(self) -> bool:
+        """Does this agent's LLM support strict JSON schema output format?"""
+        return (
+            not self.disable_strict
+            and self.llm is not None
+            and isinstance(self.llm, OpenAIGPT)
+            and self.llm.supports_json_schema
+        )
+
+    def set_system_message(self, msg: str) -> None:
+        self.system_message = msg
+        if len(self.message_history) > 0:
+            # if there is message history, update the system message in it
+            self.message_history[0].content = msg
+
+    def set_user_message(self, msg: str) -> None:
+        self.user_message = msg
+
+    @property
+    def task_messages(self) -> List[LLMMessage]:
+        """
+        The task messages are the initial messages that define the task
+        of the agent. There will be at least a system message, plus possibly a
+        user msg.
+        Returns:
+            List[LLMMessage]: the task messages
+        """
+        msgs = [self._create_system_and_tools_message()]
+        if self.user_message:
+            msgs.append(LLMMessage(role=Role.USER, content=self.user_message))
+        return msgs
+
+    def _drop_msg_update_tool_calls(self, msg: LLMMessage) -> None:
+        id2idx = {t.id: i for i, t in enumerate(self.oai_tool_calls)}
+        if msg.role == Role.TOOL:
+            # dropping a tool result, so ADD the corresponding tool-call back
+            # to the list of pending calls!
+            id = msg.tool_call_id
+            if id in self.oai_tool_id2call:
+                self.oai_tool_calls.append(self.oai_tool_id2call[id])
+        elif msg.tool_calls is not None:
+            # dropping a msg with tool-calls, so DROP these from the pending list
+            # as well as from the id -> call map
+            for tool_call in msg.tool_calls:
+                if tool_call.id in id2idx:
+                    self.oai_tool_calls.pop(id2idx[tool_call.id])
+                if tool_call.id in self.oai_tool_id2call:
+                    del self.oai_tool_id2call[tool_call.id]
+
+    def clear_history(self, start: int = -2) -> None:
+        """
+        Clear the message history, starting at the index `start`.
+
+        Args:
+            start (int): index of first message to delete; default = -2
+                (i.e. delete the last 2 messages, typically these
+                are the last user and assistant messages)
+        """
+        if start < 0:
+            n = len(self.message_history)
+            start = max(0, n + start)
+        dropped = self.message_history[start:]
+        # consider the dropped msgs in REVERSE order, so we are
+        # carefully updating self.oai_tool_calls
+        for msg in reversed(dropped):
+            self._drop_msg_update_tool_calls(msg)
+            # clear out the chat document from the ObjectRegistry
+            ChatDocument.delete_id(msg.chat_document_id)
+        self.message_history = self.message_history[:start]
+
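A standalone sketch of `clear_history`'s negative-index arithmetic (re-implemented here for illustration, assuming a plain list of strings in place of `LLMMessage`s):

    history = ["sys", "user1", "asst1", "user2", "asst2"]

    def cleared(start: int) -> list[str]:
        # mirrors clear_history: negative start counts back from the end
        if start < 0:
            start = max(0, len(history) + start)
        return history[:start]

    assert cleared(-2) == ["sys", "user1", "asst1"]  # drop last user/assistant pair
    assert cleared(0) == []                          # drop the entire history
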
+    def update_history(self, message: str, response: str) -> None:
+        """
+        Update the message history with the latest user message and LLM response.
+        Args:
+            message (str): user message
+            response (str): LLM response
+        """
+        self.message_history.extend(
+            [
+                LLMMessage(role=Role.USER, content=message),
+                LLMMessage(role=Role.ASSISTANT, content=response),
+            ]
+        )
+
+    def tool_format_rules(self) -> str:
+        """
+        Specification of tool formatting rules
+        (typically JSON-based, but can be non-JSON, e.g. XMLToolMessage),
+        based on the currently enabled usable `ToolMessage`s.
+
+        Returns:
+            str: formatting rules
+        """
+        # ONLY Usable tools (i.e. LLM-generation allowed)
+        usable_tool_classes: List[Type[ToolMessage]] = [
+            t
+            for t in list(self.llm_tools_map.values())
+            if t.default_value("request") in self.llm_tools_usable
+        ]
+
+        if len(usable_tool_classes) == 0:
+            return ""
+        format_instructions = "\n\n".join(
+            [
+                msg_cls.format_instructions(tool=self.config.use_tools)
+                for msg_cls in usable_tool_classes
+            ]
+        )
+        # if any of the enabled classes has json_group_instructions, then use
+        # that, else fall back to ToolMessage.json_group_instructions
+        for msg_cls in usable_tool_classes:
+            if hasattr(msg_cls, "json_group_instructions") and callable(
+                getattr(msg_cls, "json_group_instructions")
+            ):
+                return msg_cls.group_format_instructions().format(
+                    format_instructions=format_instructions
+                )
+        return ToolMessage.group_format_instructions().format(
+            format_instructions=format_instructions
+        )
+
+    def tool_instructions(self) -> str:
+        """
+        Instructions for tools or function-calls, for enabled and usable Tools.
+        These are inserted into the system prompt regardless of whether we are
+        using our own ToolMessage mechanism or the LLM's function-call mechanism.
+
+        Returns:
+            str: concatenation of instructions for all usable tools
+        """
+        enabled_classes: List[Type[ToolMessage]] = list(self.llm_tools_map.values())
+        if len(enabled_classes) == 0:
+            return ""
+        instructions = []
+        for msg_cls in enabled_classes:
+            if msg_cls.default_value("request") in self.llm_tools_usable:
+                class_instructions = ""
+                if hasattr(msg_cls, "instructions") and inspect.ismethod(
+                    msg_cls.instructions
+                ):
+                    class_instructions = msg_cls.instructions()
+                if (
+                    self.config.use_tools
+                    and hasattr(msg_cls, "langroid_tools_instructions")
+                    and inspect.ismethod(msg_cls.langroid_tools_instructions)
+                ):
+                    class_instructions += msg_cls.langroid_tools_instructions()
+                # example will be shown in tool_format_rules() when using TOOLs,
+                # so we don't need to show it here.
+                example = "" if self.config.use_tools else (msg_cls.usage_examples())
+                if example != "":
+                    example = "EXAMPLES:\n" + example
+                guidance = (
+                    ""
+                    if class_instructions == ""
+                    else ("GUIDANCE: " + class_instructions)
+                )
+                if guidance == "" and example == "":
+                    continue
+                instructions.append(
+                    textwrap.dedent(
+                        f"""
+                        TOOL: {msg_cls.default_value("request")}:
+                        {guidance}
+                        {example}
+                        """.lstrip()
+                    )
+                )
+        if len(instructions) == 0:
+            return ""
+        instructions_str = "\n\n".join(instructions)
+        return textwrap.dedent(
+            f"""
+            === GUIDELINES ON SOME TOOLS/FUNCTIONS USAGE ===
+            {instructions_str}
+            """.lstrip()
+        )
+
+    def augment_system_message(self, message: str) -> None:
+        """
+        Augment the system message with the given message.
+        Args:
+            message (str): system message
+        """
+        self.system_message += "\n\n" + message
+
+    def last_message_with_role(self, role: Role) -> LLMMessage | None:
+        """From `message_history`, return the last message with role `role`."""
+        n_role_msgs = len([m for m in self.message_history if m.role == role])
+        if n_role_msgs == 0:
+            return None
+        idx = self.nth_message_idx_with_role(role, n_role_msgs)
+        return self.message_history[idx]
+
+    def last_message_idx_with_role(self, role: Role) -> int:
+        """Index of the last message in message_history with the specified role.
+        Returns -1 if not found. Index = 0 is the first message in the history.
+        """
+        indices_with_role = [
+            i for i, m in enumerate(self.message_history) if m.role == role
+        ]
+        if len(indices_with_role) == 0:
+            return -1
+        return indices_with_role[-1]
+
+    def nth_message_idx_with_role(self, role: Role, n: int) -> int:
+        """Index of the `n`th message in message_history with the specified role.
+        (n is assumed to be 1-based, i.e. 1 is the first message with that role.)
+        Returns -1 if not found. Index = 0 is the first message in the history.
+        """
+        indices_with_role = [
+            i for i, m in enumerate(self.message_history) if m.role == role
+        ]
+
+        if len(indices_with_role) < n:
+            return -1
+        return indices_with_role[n - 1]
+
+    def update_last_message(self, message: str, role: str = Role.USER) -> None:
+        """
+        Update the last message that has role `role` in the message history.
+        Useful when we want to replace a long user prompt, that may contain
+        context documents plus a question, with just the question.
+        Args:
+            message (str): new message to replace with
+            role (str): role of message to replace
+        """
+        if len(self.message_history) == 0:
+            return
+        # find the last message in self.message_history with role `role`
+        for i in range(len(self.message_history) - 1, -1, -1):
+            if self.message_history[i].role == role:
+                self.message_history[i].content = message
+                break
+
+    def delete_last_message(self, role: str = Role.USER) -> None:
+        """
+        Delete the last message that has role `role` from the message history.
+        Args:
+            role (str): role of message to delete
+        """
+        if len(self.message_history) == 0:
+            return
+        # find the last message in self.message_history with role `role`
+        for i in range(len(self.message_history) - 1, -1, -1):
+            if self.message_history[i].role == role:
+                self.message_history.pop(i)
+                break
+
+    def _create_system_and_tools_message(self) -> LLMMessage:
+        """
+        (Re-)Create the system message for the LLM of the agent,
+        taking into account any tool instructions that have been added
+        after the agent was initialized.
+
+        The system message will consist of:
+        (a) the system message from the `task` arg in the constructor, if any,
+            otherwise the default system message from the config
+        (b) the system tool instructions, if any
+        (c) the system json tool instructions, if any
+
+        Returns:
+            LLMMessage object
+        """
+        content = self.system_message
+        if self.system_tool_instructions != "":
+            content += "\n\n" + self.system_tool_instructions
+        if self.system_tool_format_instructions != "":
+            content += "\n\n" + self.system_tool_format_instructions
+        if self.output_format_instructions != "":
+            content += "\n\n" + self.output_format_instructions
+
+        # remove leading and trailing newlines and other whitespace
+        return LLMMessage(role=Role.SYSTEM, content=content.strip())
+
+    def handle_message_fallback(self, msg: str | ChatDocument) -> Any:
+        """
+        Fallback method for the "no-tools" scenario, i.e., the current `msg`
+        (presumably emitted by the LLM) does not have any tool that the agent
+        can handle.
+        NOTE: The `msg` may contain tools, but either (a) the agent is not
+        enabled to handle them, or (b) there's an explicit `recipient` field
+        in the tool that doesn't match the agent's name.
+
+        Uses self.config.handle_llm_no_tool to determine the action to take.
+
+        This method can be overridden by subclasses, e.g.,
+        to create a "reminder" message when a tool is expected but the LLM
+        "forgot" to generate one.
+
+        Args:
+            msg (str | ChatDocument): The input msg to handle
+        Returns:
+            Any: The result of the handler method
+        """
+        if (
+            isinstance(msg, str)
+            or msg.metadata.sender != Entity.LLM
+            or self.config.handle_llm_no_tool is None
+            or self.has_only_unhandled_tools(msg)
+        ):
+            return None
+        # we ONLY use the `handle_llm_no_tool` config option when
+        # the msg is from the LLM and does not contain ANY tools at all.
+        from langroid.agent.tools.orchestration import AgentDoneTool, ForwardTool
+
+        no_tool_option = self.config.handle_llm_no_tool
+        if no_tool_option in list(NonToolAction):
+            # in case the `no_tool_option` is one of the special NonToolAction vals
+            match self.config.handle_llm_no_tool:
+                case NonToolAction.FORWARD_USER:
+                    return ForwardTool(agent="User")
+                case NonToolAction.DONE:
+                    return AgentDoneTool(content=msg.content, tools=msg.tool_messages)
+        elif is_callable(no_tool_option):
+            return no_tool_option(msg)
+        # Otherwise just return `no_tool_option` as is:
+        # This can be any string, such as a specific nudge/reminder to the LLM,
+        # or even something like ResultTool etc.
+        return no_tool_option
+
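Per `handle_message_fallback` above, `handle_llm_no_tool` accepts a special `NonToolAction`, a callable, or any plain value such as a reminder string. A hedged sketch of the three forms (the reminder text and the lambda are invented examples):

    import langroid as lr
    from langroid.mytypes import NonToolAction

    # forward non-tool LLM messages to the user
    cfg_forward = lr.ChatAgentConfig(handle_llm_no_tool=NonToolAction.FORWARD_USER)
    # nudge the LLM with a fixed reminder string
    cfg_remind = lr.ChatAgentConfig(
        handle_llm_no_tool="You FORGOT to use one of your tools; please retry."
    )
    # compute the response from the offending message
    cfg_callable = lr.ChatAgentConfig(
        handle_llm_no_tool=lambda msg: f"No tool found in: {msg.content}"
    )
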
+    def unhandled_tools(self) -> set[str]:
+        """The set of tools that are known but not handled.
+        Useful in task flow: an agent can refuse to accept an incoming msg
+        when it only has unhandled tools.
+        """
+        return self.llm_tools_known - self.llm_tools_handled
+
+    def enable_message(
+        self,
+        message_class: Optional[Type[ToolMessage] | List[Type[ToolMessage]]],
+        use: bool = True,
+        handle: bool = True,
+        force: bool = False,
+        require_recipient: bool = False,
+        include_defaults: bool = True,
+    ) -> None:
+        """
+        Add the tool (message class) to the agent, and enable either
+        - tool USE (i.e. the LLM can generate JSON to use this tool),
+        - tool HANDLING (i.e. the agent can handle JSON from this tool).
+
+        Args:
+            message_class: The ToolMessage class OR a list of such classes to
+                enable, for USE, or HANDLING, or both.
+                If this is a list of ToolMessage classes, then the remaining
+                args are applied to all classes.
+                Optional; if None, then apply the enabling to all tools in the
+                agent's toolset that have been enabled so far.
+            use: if True, allow the agent (LLM) to use this tool (or all tools),
+                else disallow
+            handle: if True, allow the agent (LLM) to handle (i.e. respond to)
+                this tool (or all tools)
+            force: whether to FORCE the agent (LLM) to USE the specific
+                tool represented by `message_class`.
+                `force` is ignored if `message_class` is None.
+            require_recipient: whether to require that the recipient be
+                specified when using the tool message (only applies if `use`
+                is True).
+            include_defaults: whether to include fields that have default values
+                in the "properties" section of the JSON format instructions.
+                (Normally the OpenAI completion API ignores these fields,
+                but the Assistant fn-calling seems to pay attention to these,
+                and if we don't want this, we should set this to False.)
+        """
+        if message_class is not None and isinstance(message_class, list):
+            for mc in message_class:
+                self.enable_message(
+                    mc,
+                    use=use,
+                    handle=handle,
+                    force=force,
+                    require_recipient=require_recipient,
+                    include_defaults=include_defaults,
+                )
+            return None
+        if require_recipient and message_class is not None:
+            message_class = message_class.require_recipient()
+        if message_class is not None and issubclass(message_class, XMLToolMessage):
+            # XMLToolMessage is not compatible with OpenAI's Tools/functions API,
+            # so we disable use of the functions API, and enable Langroid-native
+            # Tools, which are prompt-based.
+            self.config.use_functions_api = False
+            self.config.use_tools = True
+        super().enable_message_handling(message_class)  # enables handling only
+        tools = self._get_tool_list(message_class)
+        if message_class is not None:
+            request = message_class.default_value("request")
+            if request == "":
+                raise ValueError(
+                    f"""
+                    ToolMessage class {message_class} must have a non-empty
+                    'request' field if it is to be enabled as a tool.
+                    """
+                )
+            llm_function = message_class.llm_function_schema(defaults=include_defaults)
+            self.llm_functions_map[request] = llm_function
+            if force:
+                self.llm_function_force = dict(name=request)
+            else:
+                self.llm_function_force = None
+
+        for t in tools:
+            self.llm_tools_known.add(t)
+
+            if handle:
+                self.llm_tools_handled.add(t)
+                self.llm_functions_handled.add(t)
+
+                if (
+                    self.enabled_handling_output_format is not None
+                    and self.enabled_handling_output_format.name() == t
+                ):
+                    # `t` was designated as "enabled for handling" ONLY for
+                    # output_format enforcement, but we are explicitly
+                    # enabling it for handling here, so we set the variable
+                    # to None.
+                    self.enabled_handling_output_format = None
+            else:
+                self.llm_tools_handled.discard(t)
+                self.llm_functions_handled.discard(t)
+
+            if use:
+                tool_class = self.llm_tools_map[t]
+                if tool_class._allow_llm_use:
+                    self.llm_tools_usable.add(t)
+                    self.llm_functions_usable.add(t)
+                else:
+                    logger.warning(
+                        f"""
+                        ToolMessage class {tool_class} does not allow LLM use,
+                        because `_allow_llm_use=False` either in the Tool or a
+                        parent class of this tool;
+                        so not enabling LLM use for this tool!
+                        If you intended an LLM to use this tool,
+                        set `_allow_llm_use=True` when you define the tool.
+                        """
+                    )
+                if (
+                    self.enabled_use_output_format is not None
+                    and self.enabled_use_output_format.default_value("request") == t
+                ):
+                    # `t` was designated as "enabled for use" ONLY for
+                    # output_format enforcement, but we are explicitly enabling
+                    # it for use here, so we set the variable to None.
+                    self.enabled_use_output_format = None
+            else:
+                self.llm_tools_usable.discard(t)
+                self.llm_functions_usable.discard(t)
+
+        self._update_tool_instructions()
+
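A minimal sketch of defining a tool and enabling it via `enable_message`; `SquareTool` and its `handle` method are hypothetical, following the usual Langroid `ToolMessage` pattern:

    import langroid as lr
    from langroid.agent.tool_message import ToolMessage

    class SquareTool(ToolMessage):
        request: str = "square"  # must be non-empty, as enable_message enforces
        purpose: str = "To compute the square of a <number>."
        number: float

        def handle(self) -> str:
            # invoked by the agent when the LLM generates this tool
            return str(self.number**2)

    agent = lr.ChatAgent(lr.ChatAgentConfig())
    agent.enable_message(SquareTool, use=True, handle=True)   # LLM use + handling
    agent.enable_message(SquareTool, use=False, handle=True)  # handling only
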
+    def _update_tool_instructions(self) -> None:
+        # Set tool instructions and JSON format instructions,
+        # in case Tools have been enabled/disabled.
+        if self.config.use_tools:
+            self.system_tool_format_instructions = self.tool_format_rules()
+        self.system_tool_instructions = self.tool_instructions()
+
+    def _requests_and_tool_settings(self) -> tuple[Optional[set[str]], bool, bool]:
+        """
+        Returns the current set of enabled requests for inference, and the tools
+        configs. Used for restoring settings overridden by `set_output_format`.
+        """
+        return (
+            self.enabled_requests_for_inference,
+            self.config.use_functions_api,
+            self.config.use_tools,
+        )
+
+    @property
+    def all_llm_tools_known(self) -> set[str]:
+        """All known tools; we include `output_format` if it is a `ToolMessage`."""
+        known = self.llm_tools_known
+
+        if self.output_format is not None and issubclass(
+            self.output_format, ToolMessage
+        ):
+            return known.union({self.output_format.default_value("request")})
+
+        return known
+
+    def set_output_format(
+        self,
+        output_type: Optional[type],
+        force_tools: Optional[bool] = None,
+        use: Optional[bool] = None,
+        handle: Optional[bool] = None,
+        instructions: Optional[bool] = None,
+        is_copy: bool = False,
+    ) -> None:
+        """
+        Sets `output_format` to `output_type` and, if `force_tools` is enabled,
+        switches to the native Langroid tools mechanism to ensure that no tool
+        calls not of `output_type` are generated. By default, `force_tools`
+        follows the `use_tools_on_output_format` parameter in the config.
+
+        If `output_type` is None, restores the state prior to setting
+        `output_format`.
+
+        If `use` is True, we enable use of `output_type` when it is a subclass
+        of `ToolMessage`. Note that this primarily controls instruction
+        generation: the model will always generate `output_type` regardless
+        of whether `use` is set. Defaults to the `use_output_format`
+        parameter in the config. Similarly, handling of `output_type` is
+        controlled by `handle`, which defaults to the
+        `handle_output_format` parameter in the config.
+
+        `instructions` controls whether we generate instructions specifying
+        the output format schema. Defaults to the `instructions_output_format`
+        parameter in the config.
+
+        `is_copy` is set when called via `__getitem__`. In that case, we must
+        copy certain fields to ensure that we do not overwrite the main agent's
+        settings.
+        """
+        # Disable usage of an output format which was not specifically enabled
+        # by `enable_message`
+        if self.enabled_use_output_format is not None:
+            self.disable_message_use(self.enabled_use_output_format)
+            self.enabled_use_output_format = None
+
+        # Disable handling of an output format which did not specifically have
+        # handling enabled via `enable_message`
+        if self.enabled_handling_output_format is not None:
+            self.disable_message_handling(self.enabled_handling_output_format)
+            self.enabled_handling_output_format = None
+
+        # Reset any previous instructions
+        self.output_format_instructions = ""
+
+        if output_type is None:
+            self.output_format = None
+            (
+                requests_for_inference,
+                use_functions_api,
+                use_tools,
+            ) = self.saved_requests_and_tool_settings
+            self.config = self.config.model_copy()
+            self.enabled_requests_for_inference = requests_for_inference
+            self.config.use_functions_api = use_functions_api
+            self.config.use_tools = use_tools
+        else:
+            if force_tools is None:
+                force_tools = self.config.use_tools_on_output_format
+
+            if not any(
+                (isclass(output_type) and issubclass(output_type, t))
+                for t in [ToolMessage, BaseModel]
+            ):
+                output_type = get_pydantic_wrapper(output_type)
+
+            if self.output_format is None and force_tools:
+                self.saved_requests_and_tool_settings = (
+                    self._requests_and_tool_settings()
+                )
+
+            self.output_format = output_type
+            if issubclass(output_type, ToolMessage):
+                name = output_type.default_value("request")
+                if use is None:
+                    use = self.config.use_output_format
+
+                if handle is None:
+                    handle = self.config.handle_output_format
+
+                if use or handle:
+                    is_usable = name in self.llm_tools_usable.union(
+                        self.llm_functions_usable
+                    )
+                    is_handled = name in self.llm_tools_handled.union(
+                        self.llm_functions_handled
+                    )
+
+                    if is_copy:
+                        if use:
+                            # We must copy `llm_tools_usable` so the base agent
+                            # is unmodified
+                            self.llm_tools_usable = copy.copy(self.llm_tools_usable)
+                            self.llm_functions_usable = copy.copy(
+                                self.llm_functions_usable
+                            )
+                        if handle:
+                            # If handling the tool, do the same for
+                            # `llm_tools_handled`
+                            self.llm_tools_handled = copy.copy(self.llm_tools_handled)
+                            self.llm_functions_handled = copy.copy(
+                                self.llm_functions_handled
+                            )
+                    # Enable `output_type`
+                    self.enable_message(
+                        output_type,
+                        # Do not override existing settings
+                        use=use or is_usable,
+                        handle=handle or is_handled,
+                    )
+
+                    # If the `output_type` ToolMessage was not already enabled
+                    # for use, this means we are ONLY enabling it for use
+                    # specifically for enforcing this output format, so we set
+                    # `enabled_use_output_format` to this output_type, to
+                    # record that it should be disabled when `output_format`
+                    # is changed
+                    if not is_usable:
+                        self.enabled_use_output_format = output_type
+
+                    # (same reasoning as for use-enabling)
+                    if not is_handled:
+                        self.enabled_handling_output_format = output_type
+
+                generated_tool_instructions = name in self.llm_tools_usable.union(
+                    self.llm_functions_usable
+                )
+            else:
+                generated_tool_instructions = False
+
+            if instructions is None:
+                instructions = self.config.instructions_output_format
+            if issubclass(output_type, BaseModel) and instructions:
+                if generated_tool_instructions:
+                    # Already generated tool instructions as part of "enabling
+                    # for use", so we only need to generate a reminder to use
+                    # this tool.
+                    name = cast(ToolMessage, output_type).default_value("request")
+                    self.output_format_instructions = textwrap.dedent(
+                        f"""
+                        === OUTPUT FORMAT INSTRUCTIONS ===
+
+                        Please provide output using the `{name}` tool/function.
+                        """
+                    )
+                else:
+                    if issubclass(output_type, ToolMessage):
+                        output_format_schema = output_type.llm_function_schema(
+                            request=True,
+                            defaults=self.config.output_format_include_defaults,
+                        ).parameters
+                    else:
+                        output_format_schema = output_type.model_json_schema()
+
+                    format_schema_for_strict(output_format_schema)
+
+                    self.output_format_instructions = textwrap.dedent(
+                        f"""
+                        === OUTPUT FORMAT INSTRUCTIONS ===
+                        Please provide output as JSON with the following schema:
+
+                        {output_format_schema}
+                        """
+                    )
+
+            if force_tools:
+                if issubclass(output_type, ToolMessage):
+                    self.enabled_requests_for_inference = {
+                        output_type.default_value("request")
+                    }
+                if self.config.use_functions_api:
+                    self.config = self.config.model_copy()
+                    self.config.use_functions_api = False
+                    self.config.use_tools = True
+
+    def __getitem__(self, output_type: type) -> Self:
+        """
+        Returns a (shallow) copy of `self` with a forced output type.
+        """
+        clone = copy.copy(self)
+        clone.set_output_format(output_type, is_copy=True)
+        return clone
+
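A sketch of forcing structured output via `set_output_format` and the `__getitem__` shorthand above; `CityInfo` is an invented example model:

    import langroid as lr
    from pydantic import BaseModel

    class CityInfo(BaseModel):
        city: str
        population: int

    agent = lr.ChatAgent(lr.ChatAgentConfig())
    # mutate the agent in place...
    agent.set_output_format(CityInfo)
    # ...or take a shallow copy with the format forced, leaving `agent` as-is:
    typed_agent = agent[CityInfo]
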
+    def disable_message_handling(
+        self,
+        message_class: Optional[Type[ToolMessage]] = None,
+    ) -> None:
+        """
+        Disable this agent from RESPONDING to a `message_class` (Tool). If
+        `message_class` is None, then disable this agent from responding to ALL
+        tools.
+        Args:
+            message_class: The ToolMessage class to disable; Optional.
+        """
+        super().disable_message_handling(message_class)
+        for t in self._get_tool_list(message_class):
+            self.llm_tools_handled.discard(t)
+            self.llm_functions_handled.discard(t)
+
+    def disable_message_use(
+        self,
+        message_class: Optional[Type[ToolMessage]],
+    ) -> None:
+        """
+        Disable this agent from USING a message class (Tool).
+        If `message_class` is None, then disable this agent from USING ALL tools.
+        Args:
+            message_class: The ToolMessage class to disable.
+                If None, disable all.
+        """
+        for t in self._get_tool_list(message_class):
+            self.llm_tools_usable.discard(t)
+            self.llm_functions_usable.discard(t)
+
+        self._update_tool_instructions()
+
+    def disable_message_use_except(self, message_class: Type[ToolMessage]) -> None:
+        """
+        Disable this agent from USING ALL messages EXCEPT a message class (Tool).
+        Args:
+            message_class: The only ToolMessage class to allow
+        """
+        request = message_class.__fields__["request"].default
+        to_remove = [r for r in self.llm_tools_usable if r != request]
+        for r in to_remove:
+            self.llm_tools_usable.discard(r)
+            self.llm_functions_usable.discard(r)
+        self._update_tool_instructions()
+
+    def _load_output_format(self, message: ChatDocument) -> None:
+        """
+        If set, attempts to parse a value of type `self.output_format` from the
+        message contents or any tool/function call, and assigns it to
+        `content_any`.
+        """
+        if self.output_format is not None:
+            any_succeeded = False
+            attempts: list[str | LLMFunctionCall] = [
+                message.content,
+            ]
+
+            if message.function_call is not None:
+                attempts.append(message.function_call)
+
+            if message.oai_tool_calls is not None:
+                attempts.extend(
+                    [
+                        c.function
+                        for c in message.oai_tool_calls
+                        if c.function is not None
+                    ]
+                )
+
+            for attempt in attempts:
+                try:
+                    if isinstance(attempt, str):
+                        content = json.loads(attempt)
+                    else:
+                        if not (
+                            issubclass(self.output_format, ToolMessage)
+                            and attempt.name
+                            == self.output_format.default_value("request")
+                        ):
+                            continue
+
+                        content = attempt.arguments
+
+                    content_any = self.output_format.model_validate(content)
+
+                    if issubclass(self.output_format, PydanticWrapper):
+                        message.content_any = content_any.value  # type: ignore
+                    else:
+                        message.content_any = content_any
+                    any_succeeded = True
+                    break
+                except (ValidationError, json.JSONDecodeError):
+                    continue
+
+            if not any_succeeded:
+                self.disable_strict = True
+                logging.warning(
+                    """
+                    Validation error occurred with strict output format enabled.
+                    Disabling strict mode.
+                    """
+                )
+
+    def get_tool_messages(
+        self,
+        msg: str | ChatDocument | None,
+        all_tools: bool = False,
+    ) -> List[ToolMessage]:
+        """
+        Extracts tool messages and tracks whether any errors occurred. If strict
+        mode was enabled, disables it for the tool, else triggers strict recovery.
+        """
+        self.tool_error = False
+        most_recent_sent_by_llm = (
+            len(self.message_history) > 0
+            and self.message_history[-1].role == Role.ASSISTANT
+        )
+        was_llm = most_recent_sent_by_llm or (
+            isinstance(msg, ChatDocument) and msg.metadata.sender == Entity.LLM
+        )
+        try:
+            tools = super().get_tool_messages(msg, all_tools)
+        except ValidationError as ve:
+            tool_class = ve.model
+            if issubclass(tool_class, ToolMessage):
+                was_strict = (
+                    self.config.use_functions_api
+                    and self.config.use_tools_api
+                    and self._strict_mode_for_tool(tool_class)
+                )
+                # If the result of strict output for a tool using the
+                # OpenAI tools API fails to parse, we infer that the
+                # schema edits necessary for compatibility prevented
+                # adherence to the underlying `ToolMessage` schema, and
+                # disable strict output for the tool
+                if was_strict:
+                    name = tool_class.default_value("request")
+                    self.disable_strict_tools_set.add(name)
+                    logging.warning(
+                        f"""
+                        Validation error occurred with strict tool format.
+                        Disabling strict mode for the {name} tool.
+                        """
+                    )
+                else:
+                    # We will trigger the strict recovery mechanism to force
+                    # the LLM to correct its output, allowing us to parse
+                    if isinstance(msg, ChatDocument):
+                        self.tool_error = msg.metadata.sender == Entity.LLM
+                    else:
+                        self.tool_error = most_recent_sent_by_llm
+
+            if was_llm:
+                raise ve
+            else:
+                self.tool_error = False
+                return []
+
+        if not was_llm:
+            self.tool_error = False
+
+        return tools
+
+    def _get_any_tool_message(self, optional: bool = True) -> type[ToolMessage] | None:
+        """
+        Returns a `ToolMessage` which wraps all enabled tools, excluding those
+        where strict recovery is disabled. Used in strict recovery.
+        """
+        possible_tools = tuple(
+            self.llm_tools_map[t]
+            for t in self.llm_tools_usable
+            if t not in self.disable_strict_tools_set
+        )
+        if len(possible_tools) == 0:
+            return None
+        any_tool_type = Union.__getitem__(possible_tools)  # type: ignore
+
+        maybe_optional_type = Optional[any_tool_type] if optional else any_tool_type
+
+        class AnyTool(ToolMessage):
+            purpose: str = "To call a tool/function."
+            request: str = "tool_or_function"
+            tool: maybe_optional_type  # type: ignore
+
+            def response(self, agent: ChatAgent) -> None | str | ChatDocument:
+                # One-time use
+                agent.set_output_format(None)
+
+                if self.tool is None:
+                    return None
+
+                # As the ToolMessage schema accepts invalid
+                # `tool.request` values, reparse with the
+                # corresponding tool
+                request = self.tool.request
+                if request not in agent.llm_tools_map:
+                    return None
+                tool = agent.llm_tools_map[request].model_validate_json(
+                    self.tool.to_json()
+                )
+
+                return agent.handle_tool_message(tool)
+
+            async def response_async(
+                self, agent: ChatAgent
+            ) -> None | str | ChatDocument:
+                # One-time use
+                agent.set_output_format(None)
+
+                if self.tool is None:
+                    return None
+
+                # As the ToolMessage schema accepts invalid
+                # `tool.request` values, reparse with the
+                # corresponding tool
+                request = self.tool.request
+                if request not in agent.llm_tools_map:
+                    return None
+                tool = agent.llm_tools_map[request].model_validate_json(
+                    self.tool.to_json()
+                )
+
+                return await agent.handle_tool_message_async(tool)
+
+        return AnyTool
+
1200
+ def _strict_recovery_instructions(
1201
+ self,
1202
+ tool_type: Optional[type[ToolMessage]] = None,
1203
+ optional: bool = True,
1204
+ ) -> str:
1205
+ """Returns instructions for strict recovery."""
1206
+ optional_instructions = (
1207
+ (
1208
+ "\n"
1209
+ + """
1210
+ If you did NOT intend to do so, `tool` should be null.
1211
+ """
1212
+ )
1213
+ if optional
1214
+ else ""
1215
+ )
1216
+ response_prefix = "If you intended to make such a call, r" if optional else "R"
1217
+ instruction_prefix = "If you do so, b" if optional else "B"
1218
+
1219
+ schema_instructions = (
1220
+ f"""
1221
+ The schema for `tool_or_function` is as follows:
1222
+ {tool_type.llm_function_schema(defaults=True, request=True).parameters}
1223
+ """
1224
+ if tool_type
1225
+ else ""
1226
+ )
1227
+
1228
+ return textwrap.dedent(
1229
+ f"""
1230
+ Your previous attempt to make a tool/function call appears to have failed.
1231
+ {response_prefix}espond with your desired tool/function. Do so with the
1232
+ `tool_or_function` tool/function where `tool` is set to your intended call.
1233
+ {schema_instructions}
1234
+
1235
+ {instruction_prefix}e sure that your corrected call matches your intention
1236
+ in your previous request. For any field with a default value which
1237
+ you did not intend to override in your previous attempt, be sure
1238
+ to set that field to its default value. {optional_instructions}
1239
+ """
1240
+ )
1241
+
1242
+ def truncate_message(
1243
+ self,
1244
+ idx: int,
1245
+ tokens: int = 5,
1246
+ warning: str = "...[Contents truncated!]",
1247
+ inplace: bool = True,
1248
+ ) -> LLMMessage:
1249
+ """
1250
+ Truncate message at idx in msg history to `tokens` tokens.
1251
+
1252
+ If inplace is True, the message is truncated in place, else
1253
+ it LEAVES the original message INTACT and returns a new message
1254
+ """
1255
+ if inplace:
1256
+ llm_msg = self.message_history[idx]
1257
+ else:
1258
+ llm_msg = copy.deepcopy(self.message_history[idx])
1259
+ orig_content = llm_msg.content
1260
+ new_content = (
1261
+ self.parser.truncate_tokens(orig_content, tokens)
1262
+ if self.parser is not None
1263
+ else orig_content[: tokens * 4] # approx truncation
1264
+ )
1265
+ llm_msg.content = new_content + "\n" + warning
1266
+ return llm_msg
1267
+
1268
+ def _reduce_raw_tool_results(self, message: ChatDocument) -> None:
1269
+ """
1270
+ If message is the result of a ToolMessage that had
1271
+ a `_max_retained_tokens` set to a non-None value, then we replace contents
1272
+ with a placeholder message.
1273
+ """
1274
+ parent_message: ChatDocument | None = message.parent
1275
+ tools = [] if parent_message is None else parent_message.tool_messages
1276
+ truncate_tools = [t for t in tools if t._max_retained_tokens is not None]
1277
+ limiting_tool = truncate_tools[0] if len(truncate_tools) > 0 else None
1278
+ if limiting_tool is not None and limiting_tool._max_retained_tokens is not None:
1279
+ tool_name = limiting_tool.default_value("request")
1280
+ max_tokens: int = limiting_tool._max_retained_tokens
1281
+ truncation_warning = f"""
1282
+ The result of the {tool_name} tool were too large,
1283
+ and has been truncated to {max_tokens} tokens.
1284
+ To obtain the full result, the tool needs to be re-used.
1285
+ """
1286
+ self.truncate_message(
1287
+ message.metadata.msg_idx, max_tokens, truncation_warning
1288
+ )
1289
+
1290
+ def llm_response(
1291
+ self, message: Optional[str | ChatDocument] = None
1292
+ ) -> Optional[ChatDocument]:
1293
+ """
1294
+ Respond to a single user message, appended to the message history,
1295
+ in "chat" mode
1296
+ Args:
1297
+ message (str|ChatDocument): message or ChatDocument object to respond to.
1298
+ If None, use the self.task_messages
1299
+ Returns:
1300
+ LLM response as a ChatDocument object
1301
+ """
1302
+ if self.llm is None:
1303
+ return None
1304
+
1305
+ # If enabled and a tool error occurred, we recover by generating the tool in
1306
+ # strict json mode
1307
+ if (
1308
+ self.tool_error
1309
+ and self.output_format is None
1310
+ and self._json_schema_available()
1311
+ and self.config.strict_recovery
1312
+ ):
1313
+ self.tool_error = False
1314
+ AnyTool = self._get_any_tool_message()
1315
+ if AnyTool is None:
1316
+ return None
1317
+ self.set_output_format(
1318
+ AnyTool,
1319
+ force_tools=True,
1320
+ use=True,
1321
+ handle=True,
1322
+ instructions=True,
1323
+ )
1324
+ recovery_message = self._strict_recovery_instructions(AnyTool)
1325
+ augmented_message = message
1326
+ if augmented_message is None:
1327
+ augmented_message = recovery_message
1328
+ elif isinstance(augmented_message, str):
1329
+ augmented_message = augmented_message + recovery_message
1330
+ else:
1331
+ augmented_message.content = augmented_message.content + recovery_message
1332
+
1333
+ # only use the augmented message for this one response...
1334
+ result = self.llm_response(augmented_message)
1335
+ # ... restore the original user message so that the AnyTool recover
1336
+ # instructions don't persist in the message history
1337
+ # (this can cause the LLM to use the AnyTool directly as a tool)
1338
+            if message is None:
+                self.delete_last_message(role=Role.USER)
+            else:
+                msg = message if isinstance(message, str) else message.content
+                self.update_last_message(msg, role=Role.USER)
+            return result
+
+        hist, output_len = self._prep_llm_messages(message)
+        if len(hist) == 0:
+            return None
+        tool_choice = (
+            "auto"
+            if isinstance(message, str)
+            else (message.oai_tool_choice if message is not None else "auto")
+        )
+        with StreamingIfAllowed(self.llm, self.llm.get_stream()):
+            try:
+                response = self.llm_response_messages(hist, output_len, tool_choice)
+            except openai.BadRequestError as e:
+                if self.any_strict:
+                    self.disable_strict = True
+                    self.set_output_format(None)
+                    logging.warning(
+                        f"""
+                        OpenAI BadRequestError raised with strict mode enabled.
+                        Message: {e.message}
+                        Disabling strict mode and retrying.
+                        """
+                    )
+                    return self.llm_response(message)
+                else:
+                    raise e
+        self.message_history.extend(ChatDocument.to_LLMMessage(response))
+        response.metadata.msg_idx = len(self.message_history) - 1
+        response.metadata.agent_id = self.id
+        if isinstance(message, ChatDocument):
+            self._reduce_raw_tool_results(message)
+        # Preserve trail of tool_ids for OpenAI Assistant fn-calls
+        response.metadata.tool_ids = (
+            []
+            if isinstance(message, str)
+            else message.metadata.tool_ids if message is not None else []
+        )
+
+        return response
+
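Typical caller-side usage, as a minimal sketch (the model name is an assumption; any configured chat model works):

```python
import langroid as lr
import langroid.language_models as lm

agent = lr.ChatAgent(
    lr.ChatAgentConfig(
        llm=lm.OpenAIGPTConfig(chat_model="gpt-4o"),  # assumed model name
    )
)
resp = agent.llm_response("In one word: what is the capital of France?")
print(resp.content if resp is not None else "no response")
```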
+    async def llm_response_async(
+        self, message: Optional[str | ChatDocument] = None
+    ) -> Optional[ChatDocument]:
+        """
+        Async version of `llm_response`. See there for details.
+        """
+        if self.llm is None:
+            return None
+
+        # If enabled and a tool error occurred, we recover by generating the tool in
+        # strict json mode
+        if (
+            self.tool_error
+            and self.output_format is None
+            and self._json_schema_available()
+            and self.config.strict_recovery
+        ):
+            self.tool_error = False
+            AnyTool = self._get_any_tool_message()
+            if AnyTool is None:
+                return None
+            self.set_output_format(
+                AnyTool,
+                force_tools=True,
+                use=True,
+                handle=True,
+                instructions=True,
+            )
+            recovery_message = self._strict_recovery_instructions(AnyTool)
+            augmented_message = message
+            if augmented_message is None:
+                augmented_message = recovery_message
+            elif isinstance(augmented_message, str):
+                augmented_message = augmented_message + recovery_message
+            else:
+                augmented_message.content = augmented_message.content + recovery_message
+
+            # only use the augmented message for this one response...
+            result = await self.llm_response_async(augmented_message)
+            # ... restore the original user message so that the AnyTool recovery
+            # instructions don't persist in the message history
+            # (they can cause the LLM to use the AnyTool directly as a tool)
+            if message is None:
+                self.delete_last_message(role=Role.USER)
+            else:
+                msg = message if isinstance(message, str) else message.content
+                self.update_last_message(msg, role=Role.USER)
+            return result
+
+        hist, output_len = self._prep_llm_messages(message)
+        if len(hist) == 0:
+            return None
+        tool_choice = (
+            "auto"
+            if isinstance(message, str)
+            else (message.oai_tool_choice if message is not None else "auto")
+        )
+        with StreamingIfAllowed(self.llm, self.llm.get_stream()):
+            try:
+                response = await self.llm_response_messages_async(
+                    hist, output_len, tool_choice
+                )
+            except openai.BadRequestError as e:
+                if self.any_strict:
+                    self.disable_strict = True
+                    self.set_output_format(None)
+                    logging.warning(
+                        f"""
+                        OpenAI BadRequestError raised with strict mode enabled.
+                        Message: {e.message}
+                        Disabling strict mode and retrying.
+                        """
+                    )
+                    return await self.llm_response_async(message)
+                else:
+                    raise e
+        self.message_history.extend(ChatDocument.to_LLMMessage(response))
+        response.metadata.msg_idx = len(self.message_history) - 1
+        response.metadata.agent_id = self.id
+        if isinstance(message, ChatDocument):
+            self._reduce_raw_tool_results(message)
+        # Preserve trail of tool_ids for OpenAI Assistant fn-calls
+        response.metadata.tool_ids = (
+            []
+            if isinstance(message, str)
+            else message.metadata.tool_ids if message is not None else []
+        )
+
+        return response
+
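The async variant is a drop-in replacement when running inside an event loop; a sketch, under the same assumptions as the sync example:

```python
import asyncio

import langroid as lr
import langroid.language_models as lm

async def main() -> None:
    agent = lr.ChatAgent(
        lr.ChatAgentConfig(llm=lm.OpenAIGPTConfig(chat_model="gpt-4o"))
    )
    resp = await agent.llm_response_async("Say hello in French.")
    print(resp.content if resp is not None else "no response")

asyncio.run(main())
```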
+    def init_message_history(self) -> None:
+        """
+        Initialize the message history with the system message and user message
+        """
+        self.message_history = [self._create_system_and_tools_message()]
+        if self.user_message:
+            self.message_history.append(
+                LLMMessage(role=Role.USER, content=self.user_message)
+            )
+
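After initialization, the history always begins with the system message (which also carries tool instructions), optionally followed by the configured user message. A sketch of the expected invariant (import paths are the usual public ones):

```python
import langroid as lr
import langroid.language_models as lm

agent = lr.ChatAgent(
    lr.ChatAgentConfig(
        system_message="You are a terse assistant.",
        user_message="Ready?",
    )
)
agent.init_message_history()
assert agent.message_history[0].role == lm.Role.SYSTEM
assert agent.message_history[1].role == lm.Role.USER
```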
+    def _prep_llm_messages(
+        self,
+        message: Optional[str | ChatDocument] = None,
+        truncate: bool = True,
+    ) -> Tuple[List[LLMMessage], int]:
+        """
+        Prepare messages to be sent to self.llm_response_messages,
+        which is the main method that calls the LLM API to get a response.
+        If desired output tokens + message history exceeds the model context length,
+        the max output tokens are first reduced to fit, and if that is not
+        possible, older messages may be truncated to accommodate at least
+        self.config.llm.min_output_tokens of output.
+
+        Returns:
+            Tuple[List[LLMMessage], int]: (messages, output_len)
+                messages = Full list of messages to send
+                output_len = max expected number of tokens in response
+        """
+
+        if (
+            not self.llm_can_respond(message)
+            or self.config.llm is None
+            or self.llm is None
+        ):
+            return [], 0
+
+        if message is None and len(self.message_history) > 0:
+            # this means agent has been used to get LLM response already,
+            # and so the last message is an "assistant" response.
+            # We delete this last assistant response and re-generate it.
+            self.clear_history(-1)
+            logger.warning(
+                "Re-generating the last assistant response since message is None"
+            )
+
+        if len(self.message_history) == 0:
+            # initial messages have not yet been loaded, so load them
+            self.init_message_history()
+
+            # for debugging, show the initial message history
+            if settings.debug:
+                print(
+                    f"""
+                [grey37]LLM Initial Msg History:
+                {escape(self.message_history_str())}
+                [/grey37]
+                """
+                )
+        else:
+            assert self.message_history[0].role == Role.SYSTEM
+            # update the system message with the latest tool instructions
+            self.message_history[0] = self._create_system_and_tools_message()
+
+        if message is not None:
+            if (
+                isinstance(message, str)
+                or message.id() != self.message_history[-1].chat_document_id
+            ):
+                # either the message is a str, or it is a fresh ChatDocument
+                # different from the last message in the history
+                llm_msgs = ChatDocument.to_LLMMessage(message, self.oai_tool_calls)
+                # LLM only responds to the content, so only those msgs with
+                # non-empty content should be kept
+                llm_msgs = [m for m in llm_msgs if m.content.strip() != ""]
+                if len(llm_msgs) == 0:
+                    return [], 0
+                # process tools if any
+                done_tools = [m.tool_call_id for m in llm_msgs if m.role == Role.TOOL]
+                self.oai_tool_calls = [
+                    t for t in self.oai_tool_calls if t.id not in done_tools
+                ]
+                self.message_history.extend(llm_msgs)
+
+        hist = self.message_history
+        output_len = self.config.llm.model_max_output_tokens
+        if (
+            truncate
+            and output_len > self.llm.chat_context_length() - self.chat_num_tokens(hist)
+        ):
+            CHAT_HISTORY_BUFFER = 300
+            # chat + output > max context length,
+            # so first try to shorten requested output len to fit;
+            # use an extra margin of CHAT_HISTORY_BUFFER tokens
+            # in case our calcs are off (and to allow for some extra tokens)
+            output_len = (
+                self.llm.chat_context_length()
+                - self.chat_num_tokens(hist)
+                - CHAT_HISTORY_BUFFER
+            )
+            if output_len > self.config.llm.min_output_tokens:
+                logger.debug(
+                    f"""
+                    Chat Model context length is {self.llm.chat_context_length()},
+                    but the current message history is {self.chat_num_tokens(hist)}
+                    tokens long, which does not allow
+                    {self.config.llm.model_max_output_tokens} output tokens.
+                    Therefore we reduced `max_output_tokens` to {output_len} tokens,
+                    so they can fit within the model's context length
+                    """
+                )
+            else:
+                # unacceptably small output len, so compress early parts of conv
+                # history if output_len is still too long.
+                # TODO we should really be doing summarization or other types of
+                # prompt-size reduction
+                msg_idx_to_compress = 1  # don't touch system msg
+                # we will try compressing msg indices up to but not including
+                # last user msg
+                last_msg_idx_to_compress = (
+                    self.last_message_idx_with_role(
+                        role=Role.USER,
+                    )
+                    - 1
+                )
+                n_truncated = 0
+                while (
+                    self.chat_num_tokens(hist)
+                    > self.llm.chat_context_length() - self.config.llm.min_output_tokens
+                ):
+                    # try dropping early parts of conv history
+                    # TODO we should really be doing summarization or other types of
+                    # prompt-size reduction
+                    if msg_idx_to_compress > last_msg_idx_to_compress:
+                        # We want to preserve the first message (typically system msg)
+                        # and last message (user msg).
+                        raise ValueError(
+                            """
+                            The (message history + max_output_tokens) is longer than the
+                            max chat context length of this model, and we have tried
+                            reducing the requested max output tokens, as well as truncating
+                            early parts of the message history, to accommodate the model
+                            context length, but we have run out of msgs to drop.
+
+                            HINT: In the `llm` field of your `ChatAgentConfig` object,
+                            which is of type `LLMConfig/OpenAIGPTConfig`, try
+                            - increasing `chat_context_length`
+                              (if accurate for the model), or
+                            - decreasing `max_output_tokens`
+                            """
+                        )
+                    n_truncated += 1
+                    # compress the msg at idx `msg_idx_to_compress`
+                    hist[msg_idx_to_compress] = self.truncate_message(
+                        msg_idx_to_compress,
+                        tokens=30,
+                        warning="... [Contents truncated!]",
+                    )
+
+                    msg_idx_to_compress += 1
+
+                output_len = min(
+                    self.config.llm.model_max_output_tokens,
+                    self.llm.chat_context_length()
+                    - self.chat_num_tokens(hist)
+                    - CHAT_HISTORY_BUFFER,
+                )
+                if output_len < self.config.llm.min_output_tokens:
+                    raise ValueError(
+                        f"""
+                        Tried to shorten prompt history for chat mode
+                        but even after truncating all messages except system msg and
+                        last (user) msg,
+                        the history token len {self.chat_num_tokens(hist)} is
+                        too long to accommodate the desired minimum output tokens
+                        {self.config.llm.min_output_tokens} within the
+                        model's context length {self.llm.chat_context_length()}.
+                        Please try shortening the system msg or user prompts,
+                        or adjust `config.llm.min_output_tokens` to be smaller.
+                        """
+                    )
+                else:
+                    # we MUST have truncated at least one msg
+                    msg_tokens = self.chat_num_tokens()
+                    logger.warning(
+                        f"""
+                        Chat Model context length is {self.llm.chat_context_length()}
+                        tokens, but the current message history is {msg_tokens} tokens long,
+                        which does not allow {self.config.llm.model_max_output_tokens}
+                        output tokens.
+                        Therefore we truncated the first {n_truncated} messages
+                        in the conversation history so that history token
+                        length is reduced to {self.chat_num_tokens(hist)}, and
+                        we use `max_output_tokens = {output_len}`,
+                        so they can fit within the model's context length
+                        of {self.llm.chat_context_length()} tokens.
+                        """
+                    )
+
+        if isinstance(message, ChatDocument):
+            # record the position of the corresponding LLMMessage in
+            # the message_history
+            message.metadata.msg_idx = len(hist) - 1
+            message.metadata.agent_id = self.id
+        return hist, output_len
+
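The shrink-then-truncate logic is easiest to see with concrete numbers; a sketch of the same arithmetic with made-up values:

```python
# Made-up numbers illustrating the output-length arithmetic above.
CHAT_HISTORY_BUFFER = 300
context_len = 8_192       # model's chat context length (tokens)
history_tokens = 7_500    # tokens currently in message history
requested_output = 1_000  # config.llm.model_max_output_tokens
min_output = 64           # config.llm.min_output_tokens

# history + requested output exceeds context, so shrink the output budget:
output_len = context_len - history_tokens - CHAT_HISTORY_BUFFER  # = 392
if output_len > min_output:
    print(f"shrunk output budget to {output_len} tokens")  # this branch runs
else:
    print("would start truncating early history messages instead")
```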
+    def _function_args(
+        self,
+    ) -> Tuple[
+        Optional[List[LLMFunctionSpec]],
+        str | Dict[str, str],
+        Optional[List[OpenAIToolSpec]],
+        Optional[Dict[str, Dict[str, str] | str]],
+        Optional[OpenAIJsonSchemaSpec],
+    ]:
+        """
+        Get function/tool spec/output format arguments for
+        OpenAI-compatible LLM API call
+        """
+        functions: Optional[List[LLMFunctionSpec]] = None
+        fun_call: str | Dict[str, str] = "none"
+        tools: Optional[List[OpenAIToolSpec]] = None
+        force_tool: Optional[Dict[str, Dict[str, str] | str]] = None
+        self.any_strict = False
+        if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
+            if not self.config.use_tools_api:
+                functions = [
+                    self.llm_functions_map[f] for f in self.llm_functions_usable
+                ]
+                fun_call = (
+                    "auto"
+                    if self.llm_function_force is None
+                    else self.llm_function_force
+                )
+            else:
+
+                def to_maybe_strict_spec(function: str) -> OpenAIToolSpec:
+                    spec = self.llm_functions_map[function]
+                    strict = self._strict_mode_for_tool(function)
+                    if strict:
+                        self.any_strict = True
+                        strict_spec = copy.deepcopy(spec)
+                        format_schema_for_strict(strict_spec.parameters)
+                    else:
+                        strict_spec = spec
+
+                    return OpenAIToolSpec(
+                        type="function",
+                        strict=strict,
+                        function=strict_spec,
+                    )
+
+                tools = [to_maybe_strict_spec(f) for f in self.llm_functions_usable]
+                force_tool = (
+                    None
+                    if self.llm_function_force is None
+                    else {
+                        "type": "function",
+                        "function": {"name": self.llm_function_force["name"]},
+                    }
+                )
+        output_format = None
+        if self.output_format is not None and self._json_schema_available():
+            self.any_strict = True
+            if issubclass(self.output_format, ToolMessage) and not issubclass(
+                self.output_format, XMLToolMessage
+            ):
+                spec = self.output_format.llm_function_schema(
+                    request=True,
+                    defaults=self.config.output_format_include_defaults,
+                )
+                format_schema_for_strict(spec.parameters)
+
+                output_format = OpenAIJsonSchemaSpec(
+                    # We always require that outputs strictly match the schema
+                    strict=True,
+                    function=spec,
+                )
+            elif issubclass(self.output_format, BaseModel):
+                param_spec = self.output_format.model_json_schema()
+                format_schema_for_strict(param_spec)
+
+                output_format = OpenAIJsonSchemaSpec(
+                    # We always require that outputs strictly match the schema
+                    strict=True,
+                    function=LLMFunctionSpec(
+                        name="json_output",
+                        description="Strict Json output format.",
+                        parameters=param_spec,
+                    ),
+                )
+
+        return functions, fun_call, tools, force_tool, output_format
+
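On the caller side, `output_format` is what enables strict structured outputs. A sketch using a plain Pydantic model (the model and its fields are illustrative):

```python
import langroid as lr
from pydantic import BaseModel

class CityInfo(BaseModel):
    city: str
    country: str

agent = lr.ChatAgent(lr.ChatAgentConfig())
# Ask that LLM outputs strictly match the CityInfo JSON schema;
# _function_args() then builds the corresponding OpenAIJsonSchemaSpec
# with strict=True from CityInfo.model_json_schema().
agent.set_output_format(CityInfo)
```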
+    def llm_response_messages(
+        self,
+        messages: List[LLMMessage],
+        output_len: Optional[int] = None,
+        tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
+    ) -> ChatDocument:
+        """
+        Respond to a series of messages, e.g. with OpenAI ChatCompletion.
+        Args:
+            messages: seq of messages (with role, content fields) sent to LLM
+            output_len: max number of tokens expected in response.
+                If None, use the LLM config's default model_max_output_tokens.
+            tool_choice: tool/function-call choice to pass to the API
+                ("auto" by default).
+        Returns:
+            ChatDocument (i.e. with fields "content", "metadata")
+        """
+        assert self.config.llm is not None and self.llm is not None
+        output_len = output_len or self.config.llm.model_max_output_tokens
+        streamer = noop_fn
+        if self.llm.get_stream():
+            streamer = self.callbacks.start_llm_stream()
+        self.llm.config.streamer = streamer
+        with ExitStack() as stack:  # for conditionally using rich spinner
+            if not self.llm.get_stream() and not settings.quiet:
+                # show rich spinner only if not streaming!
+                # (Why? b/c the intent of showing a spinner is to "show progress",
+                # and we don't need to do that when streaming, since
+                # streaming output already shows progress.)
+                cm = status(
+                    "LLM responding to messages...",
+                    log_if_quiet=False,
+                )
+                stack.enter_context(cm)
+            if self.llm.get_stream() and not settings.quiet:
+                console.print(f"[green]{self.indent}", end="")
+            functions, fun_call, tools, force_tool, output_format = (
+                self._function_args()
+            )
+            assert self.llm is not None
+            response = self.llm.chat(
+                messages,
+                output_len,
+                tools=tools,
+                tool_choice=force_tool or tool_choice,
+                functions=functions,
+                function_call=fun_call,
+                response_format=output_format,
+            )
+        if self.llm.get_stream():
+            self.callbacks.finish_llm_stream(
+                content=str(response),
+                is_tool=self.has_tool_message_attempt(
+                    ChatDocument.from_LLMResponse(response, displayed=True),
+                ),
+            )
+        self.llm.config.streamer = noop_fn
+        if response.cached:
+            self.callbacks.cancel_llm_stream()
+        self._render_llm_response(response)
+        self.update_token_usage(
+            response,  # .usage attrib is updated!
+            messages,
+            self.llm.get_stream(),
+            chat=True,
+            print_response_stats=self.config.show_stats and not settings.quiet,
+        )
+        chat_doc = ChatDocument.from_LLMResponse(response, displayed=True)
+        self.oai_tool_calls = response.oai_tool_calls or []
+        self.oai_tool_id2call.update(
+            {t.id: t for t in self.oai_tool_calls if t.id is not None}
+        )
+
+        # If using strict output format, parse the output JSON
+        self._load_output_format(chat_doc)
+
+        return chat_doc
+
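This method can also be called directly with an explicit message list, bypassing the agent's history management; a minimal sketch (assumes a configured LLM):

```python
import langroid as lr
from langroid.language_models import LLMMessage, Role

agent = lr.ChatAgent(lr.ChatAgentConfig())
msgs = [
    LLMMessage(role=Role.SYSTEM, content="You answer in one word."),
    LLMMessage(role=Role.USER, content="What color is the sky?"),
]
chat_doc = agent.llm_response_messages(msgs, output_len=16)
print(chat_doc.content)
```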
+    async def llm_response_messages_async(
+        self,
+        messages: List[LLMMessage],
+        output_len: Optional[int] = None,
+        tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
+    ) -> ChatDocument:
+        """
+        Async version of `llm_response_messages`. See there for details.
+        """
+        assert self.config.llm is not None and self.llm is not None
+        output_len = output_len or self.config.llm.model_max_output_tokens
+        functions, fun_call, tools, force_tool, output_format = self._function_args()
+        assert self.llm is not None
+
+        streamer_async = async_noop_fn
+        if self.llm.get_stream():
+            streamer_async = await self.callbacks.start_llm_stream_async()
+        self.llm.config.streamer_async = streamer_async
+
+        response = await self.llm.achat(
+            messages,
+            output_len,
+            tools=tools,
+            tool_choice=force_tool or tool_choice,
+            functions=functions,
+            function_call=fun_call,
+            response_format=output_format,
+        )
+        if self.llm.get_stream():
+            self.callbacks.finish_llm_stream(
+                content=str(response),
+                is_tool=self.has_tool_message_attempt(
+                    ChatDocument.from_LLMResponse(response, displayed=True),
+                ),
+            )
+        self.llm.config.streamer_async = async_noop_fn
+        if response.cached:
+            self.callbacks.cancel_llm_stream()
+        self._render_llm_response(response)
+        self.update_token_usage(
+            response,  # .usage attrib is updated!
+            messages,
+            self.llm.get_stream(),
+            chat=True,
+            print_response_stats=self.config.show_stats and not settings.quiet,
+        )
+        chat_doc = ChatDocument.from_LLMResponse(response, displayed=True)
+        self.oai_tool_calls = response.oai_tool_calls or []
+        self.oai_tool_id2call.update(
+            {t.id: t for t in self.oai_tool_calls if t.id is not None}
+        )
+
+        # If using strict output format, parse the output JSON
+        self._load_output_format(chat_doc)
+
+        return chat_doc
+
+    def _render_llm_response(
+        self, response: ChatDocument | LLMResponse, citation_only: bool = False
+    ) -> None:
+        is_cached = (
+            response.cached
+            if isinstance(response, LLMResponse)
+            else response.metadata.cached
+        )
+        if self.llm is None:
+            return
+        if not citation_only and (not self.llm.get_stream() or is_cached):
+            # We would have already displayed the msg "live" ONLY if
+            # streaming was enabled, AND we did not find a cached response.
+            # If we are here, it means the response has not yet been displayed.
+            cached = f"[red]{self.indent}(cached)[/red]" if is_cached else ""
+            chat_doc = (
+                response
+                if isinstance(response, ChatDocument)
+                else ChatDocument.from_LLMResponse(response, displayed=True)
+            )
+            # TODO: prepend TOOL: or OAI-TOOL: if it's a tool-call
+            if not settings.quiet:
+                print(cached + "[green]" + escape(str(response)))
+            self.callbacks.show_llm_response(
+                content=str(response),
+                is_tool=self.has_tool_message_attempt(chat_doc),
+                cached=is_cached,
+            )
+        if isinstance(response, LLMResponse):
+            # we are in the context immediately after an LLM responded,
+            # we won't have citations yet, so we're done
+            return
+        if response.metadata.has_citation:
+            citation = (
+                response.metadata.source_content
+                if self.config.full_citations
+                else response.metadata.source
+            )
+            if not settings.quiet:
+                print("[grey37]SOURCES:\n" + escape(citation) + "[/grey37]")
+            self.callbacks.show_llm_response(
+                content=str(citation),
+                is_tool=False,
+                cached=False,
+                language="text",
+            )
+
+    def _llm_response_temp_context(self, message: str, prompt: str) -> ChatDocument:
+        """
+        Get LLM response to `prompt` (which presumably includes the `message`
+        somewhere, along with possible large "context" passages),
+        but only include `message` as the USER message, and not the
+        full `prompt`, in the message history.
+        Args:
+            message: the original, relatively short, user request or query
+            prompt: the full prompt potentially containing `message` plus context
+
+        Returns:
+            Document object containing the response.
+        """
+        # we explicitly call THIS class's respond method,
+        # not a derived class's (or else there would be infinite recursion!)
+        with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
+            answer_doc = cast(ChatDocument, ChatAgent.llm_response(self, prompt))
+        self.update_last_message(message, role=Role.USER)
+        return answer_doc
+
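This is the pattern RAG-style agents rely on: the LLM sees the short query plus large retrieved passages, but only the query is recorded in the history. A sketch (private method, shown purely to illustrate the prompt-vs-history distinction; the passages are made up):

```python
query = "What is the warranty period?"
passages = "Extract 1: ...\nExtract 2: ..."  # retrieved context (made up)
prompt = f"Answer based ONLY on these passages:\n{passages}\n\nQuestion: {query}"

answer = agent._llm_response_temp_context(query, prompt)
# The LLM saw the full `prompt`, but the history's last USER entry is just
# `query`, so repeated retrievals don't bloat the context over many turns.
```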
+    async def _llm_response_temp_context_async(
+        self, message: str, prompt: str
+    ) -> ChatDocument:
+        """
+        Async version of `_llm_response_temp_context`. See there for details.
+        """
+        # we explicitly call THIS class's respond method,
+        # not a derived class's (or else there would be infinite recursion!)
+        with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
+            answer_doc = cast(
+                ChatDocument,
+                await ChatAgent.llm_response_async(self, prompt),
+            )
+        self.update_last_message(message, role=Role.USER)
+        return answer_doc
+
+    def llm_response_forget(
+        self, message: Optional[str | ChatDocument] = None
+    ) -> ChatDocument:
+        """
+        LLM Response to single message, and restore message_history.
+        In effect a "one-off" message & response that leaves agent
+        message history state intact.
+
+        Args:
+            message (str|ChatDocument): message to respond to.
+
+        Returns:
+            A Document object with the response.
+
+        """
+        # explicitly call THIS class's respond method,
+        # not a derived class's (or else there would be infinite recursion!)
+        n_msgs = len(self.message_history)
+        with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
+            response = cast(ChatDocument, ChatAgent.llm_response(self, message))
+        # If there is a response, then we will have two additional
+        # messages in the message history, i.e. the user message and the
+        # assistant response. We want to (carefully) remove these two messages.
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
+
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
+
+        # If using strict output format, parse the output JSON
+        self._load_output_format(response)
+
+        return response
+
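A sketch of the "forgetting" behavior (assumes a configured LLM, as in the earlier examples):

```python
n_before = len(agent.message_history)
resp = agent.llm_response_forget("Summarize: the cat sat on the mat.")
# The one-off user message and assistant reply are not retained:
assert len(agent.message_history) == n_before
```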
+    async def llm_response_forget_async(
+        self, message: Optional[str | ChatDocument] = None
+    ) -> ChatDocument:
+        """
+        Async version of `llm_response_forget`. See there for details.
+        """
+        # explicitly call THIS class's respond method,
+        # not a derived class's (or else there would be infinite recursion!)
+        n_msgs = len(self.message_history)
+        with StreamingIfAllowed(self.llm, self.llm.get_stream()):  # type: ignore
+            response = cast(
+                ChatDocument, await ChatAgent.llm_response_async(self, message)
+            )
+        # If there is a response, then we will have two additional
+        # messages in the message history, i.e. the user message and the
+        # assistant response. We want to (carefully) remove these two messages.
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
+
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
+        return response
+
+    def chat_num_tokens(self, messages: Optional[List[LLMMessage]] = None) -> int:
+        """
+        Total number of tokens in the message history so far.
+
+        Args:
+            messages: if provided, compute the number of tokens in this list of
+                messages, rather than the current message history.
+        Returns:
+            int: number of tokens in message history
+        """
+        if self.parser is None:
+            raise ValueError(
+                "ChatAgent.parser is None. "
+                "You must set ChatAgent.parser "
+                "before calling chat_num_tokens()."
+            )
+        hist = messages if messages is not None else self.message_history
+        return sum([self.parser.num_tokens(m.content) for m in hist])
+
+    def message_history_str(self, i: Optional[int] = None) -> str:
+        """
+        Return a string representation of the message history
+        Args:
+            i: if provided, return only the i-th message when i is positive,
+                or the last k messages when i = -k.
+        Returns:
+            str: the string representation of the selected message(s)
+        """
+        if i is None:
+            return "\n".join([str(m) for m in self.message_history])
+        elif i > 0:
+            return str(self.message_history[i])
+        else:
+            return "\n".join([str(m) for m in self.message_history[i:]])
+
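The indexing semantics, as a sketch:

```python
full = agent.message_history_str()        # entire history, one msg per line
third = agent.message_history_str(3)      # only the message at index 3
last_two = agent.message_history_str(-2)  # the last two messages
```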
+    def __del__(self) -> None:
+        """
+        Cleanup method called when the ChatAgent is garbage collected.
+        Note: We don't close LLM clients here because they may be shared
+        across multiple agents when client caching is enabled.
+        The clients are managed centrally and cleaned up via atexit hooks.
+        """
+        # Previously we closed clients here, but this caused issues when
+        # multiple agents shared the same cached client instance.
+        # Clients are now managed centrally in langroid.language_models.client_cache
+        pass