langroid 0.6.7__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +499 -55
- langroid/agent/callbacks/chainlit.py +1 -1
- langroid/agent/chat_agent.py +191 -37
- langroid/agent/chat_document.py +142 -29
- langroid/agent/openai_assistant.py +20 -4
- langroid/agent/special/lance_doc_chat_agent.py +25 -18
- langroid/agent/special/lance_rag/critic_agent.py +37 -5
- langroid/agent/special/lance_rag/query_planner_agent.py +102 -63
- langroid/agent/special/lance_tools.py +10 -2
- langroid/agent/special/sql/sql_chat_agent.py +69 -13
- langroid/agent/task.py +179 -43
- langroid/agent/tool_message.py +19 -7
- langroid/agent/tools/__init__.py +5 -0
- langroid/agent/tools/orchestration.py +216 -0
- langroid/agent/tools/recipient_tool.py +6 -11
- langroid/agent/tools/rewind_tool.py +1 -1
- langroid/agent/typed_task.py +19 -0
- langroid/language_models/.chainlit/config.toml +121 -0
- langroid/language_models/.chainlit/translations/en-US.json +231 -0
- langroid/language_models/base.py +114 -12
- langroid/language_models/mock_lm.py +10 -1
- langroid/language_models/openai_gpt.py +260 -36
- langroid/mytypes.py +0 -1
- langroid/parsing/parse_json.py +19 -2
- langroid/utils/pydantic_utils.py +19 -0
- langroid/vector_store/base.py +3 -1
- langroid/vector_store/lancedb.py +2 -0
- {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/METADATA +4 -1
- {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/RECORD +32 -33
- pyproject.toml +2 -1
- langroid/agent/special/lance_rag_new/__init__.py +0 -9
- langroid/agent/special/lance_rag_new/critic_agent.py +0 -171
- langroid/agent/special/lance_rag_new/lance_rag_task.py +0 -144
- langroid/agent/special/lance_rag_new/query_planner_agent.py +0 -222
- langroid/agent/team.py +0 -1758
- {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/LICENSE +0 -0
- {langroid-0.6.7.dist-info → langroid-0.9.0.dist-info}/WHEEL +0 -0
langroid/agent/base.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
import asyncio
|
2
|
+
import copy
|
2
3
|
import inspect
|
3
4
|
import json
|
4
5
|
import logging
|
5
6
|
import re
|
6
7
|
from abc import ABC
|
8
|
+
from collections import OrderedDict
|
7
9
|
from contextlib import ExitStack
|
8
10
|
from types import SimpleNamespace
|
9
11
|
from typing import (
|
@@ -30,10 +32,13 @@ from langroid.agent.tool_message import ToolMessage
|
|
30
32
|
from langroid.language_models.base import (
|
31
33
|
LanguageModel,
|
32
34
|
LLMConfig,
|
35
|
+
LLMFunctionCall,
|
33
36
|
LLMMessage,
|
34
37
|
LLMResponse,
|
35
38
|
LLMTokenUsage,
|
39
|
+
OpenAIToolCall,
|
36
40
|
StreamingIfAllowed,
|
41
|
+
ToolChoiceTypes,
|
37
42
|
)
|
38
43
|
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
|
39
44
|
from langroid.mytypes import Entity
|
@@ -42,11 +47,12 @@ from langroid.parsing.parser import Parser, ParsingConfig
|
|
42
47
|
from langroid.prompts.prompts_config import PromptsConfig
|
43
48
|
from langroid.pydantic_v1 import BaseSettings, Field, ValidationError, validator
|
44
49
|
from langroid.utils.configuration import settings
|
45
|
-
from langroid.utils.constants import NO_ANSWER
|
50
|
+
from langroid.utils.constants import DONE, NO_ANSWER, PASS, PASS_TO, SEND_TO
|
46
51
|
from langroid.utils.object_registry import ObjectRegistry
|
47
52
|
from langroid.utils.output import status
|
48
53
|
from langroid.vector_store.base import VectorStore, VectorStoreConfig
|
49
54
|
|
55
|
+
ORCHESTRATION_STRINGS = [DONE, PASS, PASS_TO, SEND_TO]
|
50
56
|
console = Console(quiet=settings.quiet)
|
51
57
|
|
52
58
|
logger = logging.getLogger(__name__)
|
@@ -93,6 +99,11 @@ class Agent(ABC):
|
|
93
99
|
"""
|
94
100
|
|
95
101
|
id: str = Field(default_factory=lambda: ObjectRegistry.new_id())
|
102
|
+
# OpenAI tool-calls awaiting response; update when a tool result with Role.TOOL
|
103
|
+
# is added to self.message_history
|
104
|
+
oai_tool_calls: List[OpenAIToolCall] = []
|
105
|
+
# Index of ALL tool calls generated by the agent
|
106
|
+
oai_tool_id2call: Dict[str, OpenAIToolCall] = {}
|
96
107
|
|
97
108
|
def __init__(self, config: AgentConfig = AgentConfig()):
|
98
109
|
self.config = config
|
@@ -101,9 +112,8 @@ class Agent(ABC):
|
|
101
112
|
self.llm_tools_map: Dict[str, Type[ToolMessage]] = {}
|
102
113
|
self.llm_tools_handled: Set[str] = set()
|
103
114
|
self.llm_tools_usable: Set[str] = set()
|
115
|
+
self.llm_tools_known: Set[str] = set() # all known tools, handled/used or not
|
104
116
|
self.interactive: bool | None = None
|
105
|
-
self.total_llm_token_cost = 0.0
|
106
|
-
self.total_llm_token_usage = 0
|
107
117
|
self.token_stats_str = ""
|
108
118
|
self.default_human_response: Optional[str] = None
|
109
119
|
self._indent = ""
|
@@ -133,6 +143,13 @@ class Agent(ABC):
|
|
133
143
|
show_error_message=noop_fn,
|
134
144
|
show_start_response=noop_fn,
|
135
145
|
)
|
146
|
+
Agent.init_state(self)
|
147
|
+
self.init_state()
|
148
|
+
|
149
|
+
def init_state(self) -> None:
|
150
|
+
"""Initialize all state vars. Called by Task.run() if restart is True"""
|
151
|
+
self.total_llm_token_cost = 0.0
|
152
|
+
self.total_llm_token_usage = 0
|
136
153
|
|
137
154
|
@staticmethod
|
138
155
|
def from_id(id: str) -> "Agent":
|
@@ -232,7 +249,7 @@ class Agent(ABC):
|
|
232
249
|
):
|
233
250
|
"""
|
234
251
|
If the message class has a `handle` method,
|
235
|
-
and does NOT have a method with the same name as the tool,
|
252
|
+
and agent does NOT have a method with the same name as the tool,
|
236
253
|
then we create a method for the agent whose name
|
237
254
|
is the value of `tool`, and whose body is the `handle` method.
|
238
255
|
This removes a separate step of having to define this method
|
@@ -240,13 +257,25 @@ class Agent(ABC):
|
|
240
257
|
in one place, i.e. in the message class.
|
241
258
|
See `tests/main/test_stateless_tool_messages.py` for an example.
|
242
259
|
"""
|
243
|
-
|
260
|
+
has_chat_doc_arg = (
|
261
|
+
len(inspect.signature(message_class.handle).parameters) > 1
|
262
|
+
)
|
263
|
+
if has_chat_doc_arg:
|
264
|
+
setattr(self, tool, lambda obj, chat_doc: obj.handle(chat_doc))
|
265
|
+
else:
|
266
|
+
setattr(self, tool, lambda obj: obj.handle())
|
244
267
|
elif (
|
245
268
|
hasattr(message_class, "response")
|
246
269
|
and inspect.isfunction(message_class.response)
|
247
270
|
and not hasattr(self, tool)
|
248
271
|
):
|
249
|
-
|
272
|
+
has_chat_doc_arg = (
|
273
|
+
len(inspect.signature(message_class.response).parameters) > 2
|
274
|
+
)
|
275
|
+
if has_chat_doc_arg:
|
276
|
+
setattr(self, tool, lambda obj, chat_doc: obj.response(self, chat_doc))
|
277
|
+
else:
|
278
|
+
setattr(self, tool, lambda obj: obj.response(self))
|
250
279
|
|
251
280
|
if hasattr(message_class, "handle_message_fallback") and (
|
252
281
|
inspect.isfunction(message_class.handle_message_fallback)
|
@@ -304,9 +333,27 @@ class Agent(ABC):
|
|
304
333
|
]
|
305
334
|
return "\n\n".join(sample_convo)
|
306
335
|
|
307
|
-
def create_agent_response(
|
336
|
+
def create_agent_response(
|
337
|
+
self,
|
338
|
+
content: str | None = None,
|
339
|
+
tool_messages: List[ToolMessage] = [],
|
340
|
+
oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
|
341
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
342
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
343
|
+
function_call: LLMFunctionCall | None = None,
|
344
|
+
recipient: str = "",
|
345
|
+
) -> ChatDocument:
|
308
346
|
"""Template for agent_response."""
|
309
|
-
return self.
|
347
|
+
return self.response_template(
|
348
|
+
Entity.AGENT,
|
349
|
+
content=content,
|
350
|
+
tool_messages=tool_messages,
|
351
|
+
oai_tool_calls=oai_tool_calls,
|
352
|
+
oai_tool_choice=oai_tool_choice,
|
353
|
+
oai_tool_id2result=oai_tool_id2result,
|
354
|
+
function_call=function_call,
|
355
|
+
recipient=recipient,
|
356
|
+
)
|
310
357
|
|
311
358
|
async def agent_response_async(
|
312
359
|
self,
|
@@ -343,43 +390,195 @@ class Agent(ABC):
|
|
343
390
|
)
|
344
391
|
return results
|
345
392
|
if not settings.quiet:
|
393
|
+
results_str = (
|
394
|
+
results if isinstance(results, str) else json.dumps(results, indent=2)
|
395
|
+
)
|
346
396
|
console.print(f"[red]{self.indent}", end="")
|
347
|
-
print(f"[red]Agent: {escape(
|
348
|
-
maybe_json = len(extract_top_level_json(
|
397
|
+
print(f"[red]Agent: {escape(results_str)}")
|
398
|
+
maybe_json = len(extract_top_level_json(results_str)) > 0
|
349
399
|
self.callbacks.show_agent_response(
|
350
|
-
content=
|
400
|
+
content=results_str,
|
351
401
|
language="json" if maybe_json else "text",
|
352
402
|
)
|
353
403
|
sender_name = self.config.name
|
354
404
|
if isinstance(msg, ChatDocument) and msg.function_call is not None:
|
355
405
|
# if result was from handling an LLM `function_call`,
|
356
|
-
# set sender_name to
|
406
|
+
# set sender_name to name of the function_call
|
357
407
|
sender_name = msg.function_call.name
|
358
408
|
|
409
|
+
results_str, id2result, oai_tool_id = self.process_tool_results(
|
410
|
+
results if isinstance(results, str) else "",
|
411
|
+
id2result=None if isinstance(results, str) else results,
|
412
|
+
tool_calls=(msg.oai_tool_calls if isinstance(msg, ChatDocument) else None),
|
413
|
+
)
|
359
414
|
return ChatDocument(
|
360
|
-
content=
|
415
|
+
content=results_str,
|
416
|
+
oai_tool_id2result=id2result,
|
361
417
|
metadata=ChatDocMetaData(
|
362
418
|
source=Entity.AGENT,
|
363
419
|
sender=Entity.AGENT,
|
364
420
|
sender_name=sender_name,
|
421
|
+
oai_tool_id=oai_tool_id,
|
365
422
|
# preserve trail of tool_ids for OpenAI Assistant fn-calls
|
366
423
|
tool_ids=[] if isinstance(msg, str) else msg.metadata.tool_ids,
|
367
424
|
),
|
368
425
|
)
|
369
426
|
|
370
|
-
def
|
427
|
+
def process_tool_results(
|
428
|
+
self,
|
429
|
+
results: str,
|
430
|
+
id2result: OrderedDict[str, str] | None,
|
431
|
+
tool_calls: List[OpenAIToolCall] | None = None,
|
432
|
+
) -> Tuple[str, Dict[str, str] | None, str | None]:
|
433
|
+
"""
|
434
|
+
Process results from a response, based on whether
|
435
|
+
they are results of OpenAI tool-calls from THIS agent, so that
|
436
|
+
we can construct an appropriate LLMMessage that contains tool results.
|
437
|
+
|
438
|
+
Args:
|
439
|
+
results (str): A possible string result from handling tool(s)
|
440
|
+
id2result (OrderedDict[str,str]|None): A dict of OpenAI tool id -> result,
|
441
|
+
if there are multiple tool results.
|
442
|
+
tool_calls (List[OpenAIToolCall]|None): List of OpenAI tool-calls that the
|
443
|
+
results are a response to.
|
444
|
+
|
445
|
+
Return:
|
446
|
+
- str: The response string
|
447
|
+
- Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
|
448
|
+
multiple tool results.
|
449
|
+
- str|None: tool_id if there was a single tool result
|
450
|
+
|
451
|
+
"""
|
452
|
+
id2result_ = copy.deepcopy(id2result) if id2result is not None else None
|
453
|
+
results_str = ""
|
454
|
+
oai_tool_id = None
|
455
|
+
|
456
|
+
if results != "":
|
457
|
+
# in this case ignore id2result
|
458
|
+
assert (
|
459
|
+
id2result is None
|
460
|
+
), "id2result should be None when results string is non-empty!"
|
461
|
+
results_str = results
|
462
|
+
if len(self.oai_tool_calls) > 0:
|
463
|
+
# We only have one result, so in case there is a
|
464
|
+
# "pending" OpenAI tool-call, we expect no more than 1 such.
|
465
|
+
assert (
|
466
|
+
len(self.oai_tool_calls) == 1
|
467
|
+
), "There are multiple pending tool-calls, but only one result!"
|
468
|
+
# We record the tool_id of the tool-call that
|
469
|
+
# the result is a response to, so that ChatDocument.to_LLMMessage
|
470
|
+
# can properly set the `tool_call_id` field of the LLMMessage.
|
471
|
+
oai_tool_id = self.oai_tool_calls[0].id
|
472
|
+
elif id2result is not None and id2result_ is not None: # appease mypy
|
473
|
+
if len(id2result_) == len(self.oai_tool_calls):
|
474
|
+
# if the number of pending tool calls equals the number of results,
|
475
|
+
# then ignore the ids in id2result, and use the results in order,
|
476
|
+
# which is preserved since id2result is an OrderedDict.
|
477
|
+
assert len(id2result_) > 1, "Expected to see > 1 result in id2result!"
|
478
|
+
results_str = ""
|
479
|
+
id2result_ = OrderedDict(
|
480
|
+
zip(
|
481
|
+
[tc.id or "" for tc in self.oai_tool_calls], id2result_.values()
|
482
|
+
)
|
483
|
+
)
|
484
|
+
else:
|
485
|
+
assert (
|
486
|
+
tool_calls is not None
|
487
|
+
), "tool_calls cannot be None when id2result is not None!"
|
488
|
+
# This must be an OpenAI tool id -> result map;
|
489
|
+
# However some ids may not correspond to the tool-calls in the list of
|
490
|
+
# pending tool-calls (self.oai_tool_calls).
|
491
|
+
# Such results are concatenated into a simple string, to store in the
|
492
|
+
# ChatDocument.content, and the rest
|
493
|
+
# (i.e. those that DO correspond to tools in self.oai_tool_calls)
|
494
|
+
# are stored as a dict in ChatDocument.oai_tool_id2result.
|
495
|
+
|
496
|
+
# OAI tools from THIS agent, awaiting response
|
497
|
+
pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
|
498
|
+
# tool_calls that the results are a response to
|
499
|
+
# (but these may have been sent from another agent, hence may not be in
|
500
|
+
# self.oai_tool_calls)
|
501
|
+
parent_tool_id2name = {
|
502
|
+
tc.id: tc.function.name
|
503
|
+
for tc in tool_calls or []
|
504
|
+
if tc.function is not None
|
505
|
+
}
|
506
|
+
|
507
|
+
# (id, result) for result NOT corresponding to self.oai_tool_calls,
|
508
|
+
# i.e. these are results of EXTERNAL tool-calls from another agent.
|
509
|
+
external_tool_id_results = []
|
510
|
+
|
511
|
+
for tc_id, result in id2result.items():
|
512
|
+
if tc_id not in pending_tool_ids:
|
513
|
+
external_tool_id_results.append((tc_id, result))
|
514
|
+
id2result_.pop(tc_id)
|
515
|
+
if len(external_tool_id_results) == 0:
|
516
|
+
results_str = ""
|
517
|
+
elif len(external_tool_id_results) == 1:
|
518
|
+
results_str = external_tool_id_results[0][1]
|
519
|
+
else:
|
520
|
+
results_str = "\n\n".join(
|
521
|
+
[
|
522
|
+
f"Result from tool/function "
|
523
|
+
f"{parent_tool_id2name[id]}: {result}"
|
524
|
+
for id, result in external_tool_id_results
|
525
|
+
]
|
526
|
+
)
|
527
|
+
|
528
|
+
if len(id2result_) == 0:
|
529
|
+
id2result_ = None
|
530
|
+
elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
|
531
|
+
results_str = list(id2result_.values())[0]
|
532
|
+
oai_tool_id = list(id2result_.keys())[0]
|
533
|
+
id2result_ = None
|
534
|
+
|
535
|
+
return results_str, id2result_, oai_tool_id
|
536
|
+
|
537
|
+
def response_template(
|
538
|
+
self,
|
539
|
+
e: Entity,
|
540
|
+
content: str | None = None,
|
541
|
+
tool_messages: List[ToolMessage] = [],
|
542
|
+
oai_tool_calls: Optional[List[OpenAIToolCall]] = None,
|
543
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
544
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
545
|
+
function_call: LLMFunctionCall | None = None,
|
546
|
+
recipient: str = "",
|
547
|
+
) -> ChatDocument:
|
371
548
|
"""Template for response from entity `e`."""
|
372
549
|
return ChatDocument(
|
373
550
|
content=content or "",
|
374
|
-
tool_messages=
|
551
|
+
tool_messages=tool_messages,
|
552
|
+
oai_tool_calls=oai_tool_calls,
|
553
|
+
oai_tool_id2result=oai_tool_id2result,
|
554
|
+
function_call=function_call,
|
555
|
+
oai_tool_choice=oai_tool_choice,
|
375
556
|
metadata=ChatDocMetaData(
|
376
|
-
source=e, sender=e, sender_name=self.config.name,
|
557
|
+
source=e, sender=e, sender_name=self.config.name, recipient=recipient
|
377
558
|
),
|
378
559
|
)
|
379
560
|
|
380
|
-
def create_user_response(
|
561
|
+
def create_user_response(
|
562
|
+
self,
|
563
|
+
content: str | None = None,
|
564
|
+
tool_messages: List[ToolMessage] = [],
|
565
|
+
oai_tool_calls: List[OpenAIToolCall] | None = None,
|
566
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
567
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
568
|
+
function_call: LLMFunctionCall | None = None,
|
569
|
+
recipient: str = "",
|
570
|
+
) -> ChatDocument:
|
381
571
|
"""Template for user_response."""
|
382
|
-
return self.
|
572
|
+
return self.response_template(
|
573
|
+
e=Entity.USER,
|
574
|
+
content=content,
|
575
|
+
tool_messages=tool_messages,
|
576
|
+
oai_tool_calls=oai_tool_calls,
|
577
|
+
oai_tool_choice=oai_tool_choice,
|
578
|
+
oai_tool_id2result=oai_tool_id2result,
|
579
|
+
function_call=function_call,
|
580
|
+
recipient=recipient,
|
581
|
+
)
|
383
582
|
|
384
583
|
async def user_response_async(
|
385
584
|
self,
|
@@ -473,9 +672,27 @@ class Agent(ABC):
|
|
473
672
|
|
474
673
|
return True
|
475
674
|
|
476
|
-
def create_llm_response(
|
675
|
+
def create_llm_response(
|
676
|
+
self,
|
677
|
+
content: str | None = None,
|
678
|
+
tool_messages: List[ToolMessage] = [],
|
679
|
+
oai_tool_calls: None | List[OpenAIToolCall] = None,
|
680
|
+
oai_tool_choice: ToolChoiceTypes | Dict[str, Dict[str, str] | str] = "auto",
|
681
|
+
oai_tool_id2result: OrderedDict[str, str] | None = None,
|
682
|
+
function_call: LLMFunctionCall | None = None,
|
683
|
+
recipient: str = "",
|
684
|
+
) -> ChatDocument:
|
477
685
|
"""Template for llm_response."""
|
478
|
-
return self.
|
686
|
+
return self.response_template(
|
687
|
+
Entity.LLM,
|
688
|
+
content=content,
|
689
|
+
tool_messages=tool_messages,
|
690
|
+
oai_tool_calls=oai_tool_calls,
|
691
|
+
oai_tool_choice=oai_tool_choice,
|
692
|
+
oai_tool_id2result=oai_tool_id2result,
|
693
|
+
function_call=function_call,
|
694
|
+
recipient=recipient,
|
695
|
+
)
|
479
696
|
|
480
697
|
@no_type_check
|
481
698
|
async def llm_response_async(
|
@@ -619,25 +836,95 @@ class Agent(ABC):
|
|
619
836
|
return True
|
620
837
|
return False
|
621
838
|
|
622
|
-
def
|
839
|
+
def _tool_recipient_match(self, tool: ToolMessage) -> bool:
|
840
|
+
"""Is tool is handled by this agent
|
841
|
+
and an explicit `recipient` field doesn't preclude this agent from handling it?
|
842
|
+
"""
|
843
|
+
if tool.default_value("request") not in self.llm_tools_handled:
|
844
|
+
return False
|
845
|
+
if hasattr(tool, "recipient") and isinstance(tool.recipient, str):
|
846
|
+
return tool.recipient == "" or tool.recipient == self.config.name
|
847
|
+
return True
|
848
|
+
|
849
|
+
def has_only_unhandled_tools(self, msg: str | ChatDocument) -> bool:
|
850
|
+
"""
|
851
|
+
Does the msg have at least one tool, and ALL tools are
|
852
|
+
disabled for handling by this agent?
|
853
|
+
"""
|
854
|
+
tools = self.get_tool_messages(msg, all_tools=True)
|
855
|
+
if len(tools) == 0:
|
856
|
+
return False
|
857
|
+
return all(not self._tool_recipient_match(t) for t in tools)
|
858
|
+
|
859
|
+
def get_tool_messages(
|
860
|
+
self,
|
861
|
+
msg: str | ChatDocument,
|
862
|
+
all_tools: bool = False,
|
863
|
+
) -> List[ToolMessage]:
|
864
|
+
"""
|
865
|
+
Get ToolMessages recognized in msg, handle-able by this agent.
|
866
|
+
If all_tools is True:
|
867
|
+
- return all tools, i.e. any tool in self.llm_tools_known,
|
868
|
+
whether it is handled by this agent or not;
|
869
|
+
- otherwise, return only the tools handled by this agent.
|
870
|
+
"""
|
871
|
+
|
623
872
|
if isinstance(msg, str):
|
624
|
-
|
873
|
+
json_tools = self.get_json_tool_messages(msg)
|
874
|
+
if all_tools:
|
875
|
+
return json_tools
|
876
|
+
else:
|
877
|
+
return [
|
878
|
+
t
|
879
|
+
for t in json_tools
|
880
|
+
if self._tool_recipient_match(t) and t.default_value("request")
|
881
|
+
]
|
882
|
+
|
883
|
+
if all_tools and len(msg.all_tool_messages) > 0:
|
884
|
+
# We've already identified all_tool_messages in the msg;
|
885
|
+
# return the corresponding ToolMessage objects
|
886
|
+
return msg.all_tool_messages
|
625
887
|
if len(msg.tool_messages) > 0:
|
626
|
-
# We've already found tool_messages
|
627
|
-
# (either via OpenAI Fn-call or Langroid-native ToolMessage)
|
888
|
+
# We've already found tool_messages,
|
889
|
+
# (either via OpenAI Fn-call or Langroid-native ToolMessage);
|
890
|
+
# or they were added by an agent_response.
|
891
|
+
# note these could be from a forwarded msg from another agent,
|
892
|
+
# so return ONLY the messages THIS agent to enabled to handle.
|
628
893
|
return msg.tool_messages
|
629
894
|
assert isinstance(msg, ChatDocument)
|
630
|
-
|
631
|
-
|
895
|
+
if (
|
896
|
+
msg.content != ""
|
897
|
+
and msg.oai_tool_calls is None
|
898
|
+
and msg.function_call is None
|
899
|
+
):
|
900
|
+
|
632
901
|
tools = self.get_json_tool_messages(msg.content)
|
633
|
-
msg.
|
634
|
-
|
902
|
+
msg.all_tool_messages = tools
|
903
|
+
my_tools = [t for t in tools if self._tool_recipient_match(t)]
|
904
|
+
msg.tool_messages = my_tools
|
635
905
|
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
906
|
+
if all_tools:
|
907
|
+
return tools
|
908
|
+
else:
|
909
|
+
return my_tools
|
910
|
+
|
911
|
+
# otherwise, we look for `tool_calls` (possibly multiple)
|
912
|
+
tools = self.get_oai_tool_calls_classes(msg)
|
913
|
+
msg.all_tool_messages = tools
|
914
|
+
my_tools = [t for t in tools if self._tool_recipient_match(t)]
|
915
|
+
msg.tool_messages = my_tools
|
916
|
+
|
917
|
+
if len(tools) == 0:
|
918
|
+
# otherwise, we look for a `function_call`
|
919
|
+
fun_call_cls = self.get_function_call_class(msg)
|
920
|
+
tools = [fun_call_cls] if fun_call_cls is not None else []
|
921
|
+
msg.all_tool_messages = tools
|
922
|
+
my_tools = [t for t in tools if self._tool_recipient_match(t)]
|
923
|
+
msg.tool_messages = my_tools
|
924
|
+
if all_tools:
|
925
|
+
return tools
|
926
|
+
else:
|
927
|
+
return my_tools
|
641
928
|
|
642
929
|
def get_json_tool_messages(self, input_str: str) -> List[ToolMessage]:
|
643
930
|
"""
|
@@ -656,6 +943,10 @@ class Agent(ABC):
|
|
656
943
|
return [r for r in results if r is not None]
|
657
944
|
|
658
945
|
def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]:
|
946
|
+
"""
|
947
|
+
From ChatDocument (constructed from an LLM Response), get the `ToolMessage`
|
948
|
+
corresponding to the `function_call` if it exists.
|
949
|
+
"""
|
659
950
|
if msg.function_call is None:
|
660
951
|
return None
|
661
952
|
tool_name = msg.function_call.name
|
@@ -677,6 +968,39 @@ class Agent(ABC):
|
|
677
968
|
tool = tool_class.parse_obj(tool_msg)
|
678
969
|
return tool
|
679
970
|
|
971
|
+
def get_oai_tool_calls_classes(self, msg: ChatDocument) -> List[ToolMessage]:
|
972
|
+
"""
|
973
|
+
From ChatDocument (constructed from an LLM Response), get
|
974
|
+
a list of ToolMessages corresponding to the `tool_calls`, if any.
|
975
|
+
"""
|
976
|
+
|
977
|
+
if msg.oai_tool_calls is None:
|
978
|
+
return []
|
979
|
+
tools = []
|
980
|
+
for tc in msg.oai_tool_calls:
|
981
|
+
if tc.function is None:
|
982
|
+
continue
|
983
|
+
tool_name = tc.function.name
|
984
|
+
tool_msg = tc.function.arguments or {}
|
985
|
+
if tool_name not in self.llm_tools_handled:
|
986
|
+
logger.warning(
|
987
|
+
f"""
|
988
|
+
The tool_call '{tool_name}' is not handled
|
989
|
+
by the agent named '{self.config.name}'!
|
990
|
+
If you intended this agent to handle this function_call,
|
991
|
+
either the fn-call name is incorrectly generated by the LLM,
|
992
|
+
(in which case you may need to adjust your LLM instructions),
|
993
|
+
or you need to enable this agent to handle this fn-call.
|
994
|
+
"""
|
995
|
+
)
|
996
|
+
continue
|
997
|
+
tool_class = self.llm_tools_map[tool_name]
|
998
|
+
tool_msg.update(dict(request=tool_name))
|
999
|
+
tool = tool_class.parse_obj(tool_msg)
|
1000
|
+
tool.id = tc.id or ""
|
1001
|
+
tools.append(tool)
|
1002
|
+
return tools
|
1003
|
+
|
680
1004
|
def tool_validation_error(self, ve: ValidationError) -> str:
|
681
1005
|
"""
|
682
1006
|
Handle a validation error raised when parsing a tool message,
|
@@ -699,7 +1023,9 @@ class Agent(ABC):
|
|
699
1023
|
Please write your message again, correcting the errors.
|
700
1024
|
"""
|
701
1025
|
|
702
|
-
def handle_message(
|
1026
|
+
def handle_message(
|
1027
|
+
self, msg: str | ChatDocument
|
1028
|
+
) -> None | str | OrderedDict[str, str] | ChatDocument:
|
703
1029
|
"""
|
704
1030
|
Handle a "tool" message either a string containing one or more
|
705
1031
|
valid "tool" JSON substrings, or a
|
@@ -711,9 +1037,16 @@ class Agent(ABC):
|
|
711
1037
|
msg (str | ChatDocument): The string or ChatDocument to handle
|
712
1038
|
|
713
1039
|
Returns:
|
714
|
-
|
715
|
-
|
716
|
-
|
1040
|
+
The result of the handler method can be:
|
1041
|
+
- None if no tools successfully handled, or no tools present
|
1042
|
+
- str if langroid-native JSON tools were handled, and results concatenated,
|
1043
|
+
OR there's a SINGLE OpenAI tool-call.
|
1044
|
+
(We do this so the common scenario of a single tool/fn-call
|
1045
|
+
has a simple behavior).
|
1046
|
+
- Dict[str, str] if multiple OpenAI tool-calls were handled
|
1047
|
+
(dict is an id->result map)
|
1048
|
+
- ChatDocument if a handler returned a ChatDocument, intended to be the
|
1049
|
+
final response of the `agent_response` method.
|
717
1050
|
"""
|
718
1051
|
try:
|
719
1052
|
tools = self.get_tool_messages(msg)
|
@@ -728,15 +1061,73 @@ class Agent(ABC):
|
|
728
1061
|
# for this agent.
|
729
1062
|
return None
|
730
1063
|
if len(tools) == 0:
|
731
|
-
|
1064
|
+
fallback_result = self.handle_message_fallback(msg)
|
1065
|
+
if fallback_result is not None and isinstance(fallback_result, ToolMessage):
|
1066
|
+
return self.create_agent_response(tool_messages=[fallback_result])
|
1067
|
+
return fallback_result
|
1068
|
+
has_ids = all([t.id != "" for t in tools])
|
1069
|
+
chat_doc = msg if isinstance(msg, ChatDocument) else None
|
1070
|
+
|
1071
|
+
# check whether there are multiple orchestration-tools (e.g. DoneTool etc),
|
1072
|
+
# in which case set result to error-string since we don't yet support
|
1073
|
+
# multi-tools with one or more orch tools.
|
1074
|
+
from langroid.agent.tools.orchestration import (
|
1075
|
+
AgentDoneTool,
|
1076
|
+
AgentSendTool,
|
1077
|
+
DonePassTool,
|
1078
|
+
DoneTool,
|
1079
|
+
ForwardTool,
|
1080
|
+
PassTool,
|
1081
|
+
SendTool,
|
1082
|
+
)
|
1083
|
+
from langroid.agent.tools.recipient_tool import RecipientTool
|
1084
|
+
|
1085
|
+
ORCHESTRATION_TOOLS = (
|
1086
|
+
AgentDoneTool,
|
1087
|
+
DoneTool,
|
1088
|
+
PassTool,
|
1089
|
+
DonePassTool,
|
1090
|
+
ForwardTool,
|
1091
|
+
RecipientTool,
|
1092
|
+
SendTool,
|
1093
|
+
AgentSendTool,
|
1094
|
+
)
|
732
1095
|
|
733
|
-
|
1096
|
+
has_orch = any(isinstance(t, ORCHESTRATION_TOOLS) for t in tools)
|
1097
|
+
results: List[str | ChatDocument | None]
|
1098
|
+
if has_orch and len(tools) > 1:
|
1099
|
+
err_str = "ERROR: Use ONE tool at a time!"
|
1100
|
+
results = [err_str for _ in tools]
|
1101
|
+
else:
|
1102
|
+
results = [self.handle_tool_message(t, chat_doc=chat_doc) for t in tools]
|
1103
|
+
|
1104
|
+
tool_names = [t.default_value("request") for t in tools]
|
1105
|
+
if has_ids:
|
1106
|
+
id2result = OrderedDict(
|
1107
|
+
(t.id, r)
|
1108
|
+
for t, r in zip(tools, results)
|
1109
|
+
if r is not None and isinstance(r, str)
|
1110
|
+
)
|
1111
|
+
result_values = list(id2result.values())
|
1112
|
+
if len(id2result) > 1 and any(
|
1113
|
+
orch_str in r
|
1114
|
+
for r in result_values
|
1115
|
+
for orch_str in ORCHESTRATION_STRINGS
|
1116
|
+
):
|
1117
|
+
# Cannot support multi-tool results containing orchestration strings!
|
1118
|
+
# Replace results with err string to force LLM to retry
|
1119
|
+
err_str = "ERROR: Please use ONE tool at a time!"
|
1120
|
+
id2result = OrderedDict((id, err_str) for id in id2result.keys())
|
734
1121
|
|
735
|
-
|
736
|
-
|
1122
|
+
name_results_list = [
|
1123
|
+
(name, r) for name, r in zip(tool_names, results) if r is not None
|
1124
|
+
]
|
1125
|
+
if len(name_results_list) == 0:
|
737
1126
|
return None # self.handle_message_fallback(msg)
|
738
1127
|
# there was a non-None result
|
739
|
-
chat_doc_results = [
|
1128
|
+
chat_doc_results = [
|
1129
|
+
r for _, r in name_results_list if isinstance(r, ChatDocument)
|
1130
|
+
]
|
740
1131
|
if len(chat_doc_results) > 1:
|
741
1132
|
logger.warning(
|
742
1133
|
"""There were multiple ChatDocument results from tools,
|
@@ -747,27 +1138,49 @@ class Agent(ABC):
|
|
747
1138
|
if len(chat_doc_results) > 0:
|
748
1139
|
return chat_doc_results[0]
|
749
1140
|
|
750
|
-
|
751
|
-
|
1141
|
+
if has_ids and len(id2result) > 1:
|
1142
|
+
# if there are multiple OpenAI Tool results, return them as a dict
|
1143
|
+
return id2result
|
1144
|
+
|
1145
|
+
if len(name_results_list) == 1 and isinstance(name_results_list[0][1], str):
|
1146
|
+
# single str result -- return it as is
|
1147
|
+
return name_results_list[0][1]
|
1148
|
+
|
1149
|
+
# multi-results: prepend the tool name to each result
|
1150
|
+
str_results = [
|
1151
|
+
f"Result from {name}: {r}"
|
1152
|
+
for name, r in name_results_list
|
1153
|
+
if isinstance(r, str)
|
1154
|
+
]
|
1155
|
+
final = "\n\n".join(str_results)
|
752
1156
|
return final
|
753
1157
|
|
754
1158
|
def handle_message_fallback(
|
755
1159
|
self, msg: str | ChatDocument
|
756
1160
|
) -> str | ChatDocument | None:
|
757
1161
|
"""
|
758
|
-
Fallback method
|
759
|
-
|
760
|
-
|
1162
|
+
Fallback method for the "no-tools" scenario.
|
1163
|
+
This method can be overridden by subclasses, e.g.,
|
1164
|
+
to create a "reminder" message when a tool is expected but the LLM "forgot"
|
1165
|
+
to generate one.
|
761
1166
|
|
762
1167
|
Args:
|
763
1168
|
msg (str | ChatDocument): The input msg to handle
|
764
1169
|
Returns:
|
765
1170
|
str: The result of the handler method in string form so it can
|
766
|
-
be
|
1171
|
+
be handled by LLM
|
767
1172
|
"""
|
768
1173
|
return None
|
769
1174
|
|
770
1175
|
def _get_one_tool_message(self, json_str: str) -> Optional[ToolMessage]:
|
1176
|
+
"""
|
1177
|
+
Parse the json str into ANY ToolMessage KNOWN to agent --
|
1178
|
+
This includes non-used/handled tools, i.e. any tool in self.llm_tools_known.
|
1179
|
+
The exception to this is below where we try our best to infer the tool
|
1180
|
+
when the LLM has "forgotten" to include the "request" field in the JSON --
|
1181
|
+
in this case we ONLY look at the possible set of HANDLED tools, i.e.
|
1182
|
+
self.llm_tools_handled.
|
1183
|
+
"""
|
771
1184
|
json_data = json.loads(json_str)
|
772
1185
|
# check if the json_data contains a "properties" field
|
773
1186
|
# which further contains the actual tool-call
|
@@ -791,9 +1204,8 @@ class Agent(ABC):
|
|
791
1204
|
if isinstance(properties, dict):
|
792
1205
|
json_data = properties
|
793
1206
|
request = json_data.get("request")
|
794
|
-
|
795
1207
|
if request is None:
|
796
|
-
|
1208
|
+
possible = [self.llm_tools_map[r] for r in self.llm_tools_handled]
|
797
1209
|
default_keys = set(ToolMessage.__fields__.keys())
|
798
1210
|
request_keys = set(json_data.keys())
|
799
1211
|
|
@@ -817,7 +1229,7 @@ class Agent(ABC):
|
|
817
1229
|
candidate_tools = list(
|
818
1230
|
filter(
|
819
1231
|
lambda t: t is not None,
|
820
|
-
map(maybe_parse,
|
1232
|
+
map(maybe_parse, possible),
|
821
1233
|
)
|
822
1234
|
)
|
823
1235
|
|
@@ -828,7 +1240,7 @@ class Agent(ABC):
|
|
828
1240
|
else:
|
829
1241
|
return None
|
830
1242
|
|
831
|
-
if not isinstance(request, str) or request not in self.
|
1243
|
+
if not isinstance(request, str) or request not in self.llm_tools_known:
|
832
1244
|
return None
|
833
1245
|
|
834
1246
|
message_class = self.llm_tools_map.get(request)
|
@@ -842,11 +1254,18 @@ class Agent(ABC):
|
|
842
1254
|
raise ve
|
843
1255
|
return message
|
844
1256
|
|
845
|
-
def handle_tool_message(
|
1257
|
+
def handle_tool_message(
|
1258
|
+
self,
|
1259
|
+
tool: ToolMessage,
|
1260
|
+
chat_doc: Optional[ChatDocument] = None,
|
1261
|
+
) -> None | str | ChatDocument:
|
846
1262
|
"""
|
847
1263
|
Respond to a tool request from the LLM, in the form of an ToolMessage object.
|
848
1264
|
Args:
|
849
1265
|
tool: ToolMessage object representing the tool request.
|
1266
|
+
chat_doc: Optional ChatDocument object containing the tool request.
|
1267
|
+
This is passed to the tool-handler method only if it has a `chat_doc`
|
1268
|
+
argument.
|
850
1269
|
|
851
1270
|
Returns:
|
852
1271
|
|
@@ -855,9 +1274,34 @@ class Agent(ABC):
|
|
855
1274
|
handler_method = getattr(self, tool_name, None)
|
856
1275
|
if handler_method is None:
|
857
1276
|
return None
|
858
|
-
|
1277
|
+
has_chat_doc_arg = (
|
1278
|
+
chat_doc is not None
|
1279
|
+
and "chat_doc" in inspect.signature(handler_method).parameters
|
1280
|
+
)
|
859
1281
|
try:
|
860
|
-
|
1282
|
+
if has_chat_doc_arg:
|
1283
|
+
maybe_result = handler_method(tool, chat_doc=chat_doc)
|
1284
|
+
else:
|
1285
|
+
maybe_result = handler_method(tool)
|
1286
|
+
if isinstance(maybe_result, ToolMessage):
|
1287
|
+
# result is a ToolMessage, so...
|
1288
|
+
result_tool_name = maybe_result.default_value("request")
|
1289
|
+
if (
|
1290
|
+
result_tool_name in self.llm_tools_handled
|
1291
|
+
and tool_name != result_tool_name
|
1292
|
+
):
|
1293
|
+
# TODO: do we need to remove the tool message from the chat_doc?
|
1294
|
+
# if (chat_doc is not None and
|
1295
|
+
# maybe_result in chat_doc.tool_messages):
|
1296
|
+
# chat_doc.tool_messages.remove(maybe_result)
|
1297
|
+
# if we can handle it, do so
|
1298
|
+
result = self.handle_tool_message(maybe_result, chat_doc=chat_doc)
|
1299
|
+
else:
|
1300
|
+
# else wrap it in an agent response and return it so
|
1301
|
+
# orchestrator can find a respondent
|
1302
|
+
result = self.create_agent_response(tool_messages=[maybe_result])
|
1303
|
+
else:
|
1304
|
+
result = maybe_result
|
861
1305
|
except Exception as e:
|
862
1306
|
# raise the error here since we are sure it's
|
863
1307
|
# not a pydantic validation error,
|