langroid 0.6.7__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langroid/agent/base.py +206 -21
- langroid/agent/callbacks/chainlit.py +1 -1
- langroid/agent/chat_agent.py +124 -29
- langroid/agent/chat_document.py +132 -28
- langroid/agent/openai_assistant.py +8 -3
- langroid/agent/special/sql/sql_chat_agent.py +69 -13
- langroid/agent/task.py +36 -9
- langroid/agent/tool_message.py +8 -5
- langroid/agent/tools/rewind_tool.py +1 -1
- langroid/language_models/.chainlit/config.toml +121 -0
- langroid/language_models/.chainlit/translations/en-US.json +231 -0
- langroid/language_models/base.py +111 -10
- langroid/language_models/mock_lm.py +10 -1
- langroid/language_models/openai_gpt.py +260 -36
- {langroid-0.6.7.dist-info → langroid-0.8.0.dist-info}/METADATA +3 -1
- {langroid-0.6.7.dist-info → langroid-0.8.0.dist-info}/RECORD +19 -17
- pyproject.toml +1 -1
- {langroid-0.6.7.dist-info → langroid-0.8.0.dist-info}/LICENSE +0 -0
- {langroid-0.6.7.dist-info → langroid-0.8.0.dist-info}/WHEEL +0 -0
langroid/agent/base.py
CHANGED
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import inspect
 import json
 import logging
@@ -33,6 +34,7 @@ from langroid.language_models.base import (
     LLMMessage,
     LLMResponse,
     LLMTokenUsage,
+    OpenAIToolCall,
     StreamingIfAllowed,
 )
 from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig
@@ -93,6 +95,11 @@ class Agent(ABC):
     """

     id: str = Field(default_factory=lambda: ObjectRegistry.new_id())
+    # OpenAI tool-calls awaiting response; update when a tool result with Role.TOOL
+    # is added to self.message_history
+    oai_tool_calls: List[OpenAIToolCall] = []
+    # Index of ALL tool calls generated by the agent
+    oai_tool_id2call: Dict[str, OpenAIToolCall] = {}

     def __init__(self, config: AgentConfig = AgentConfig()):
         self.config = config
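The two new fields amount to simple bookkeeping for OpenAI tool-calls: a call enters the pending list when the LLM emits it, leaves when a Role.TOOL result for its id is recorded, and stays in the id index forever. A minimal self-contained sketch of that invariant (the `ToolCall` stand-in and method names are invented for illustration; langroid's actual type is `OpenAIToolCall`):

    from dataclasses import dataclass, field
    from typing import Dict, List

    @dataclass
    class ToolCall:  # stand-in for langroid's OpenAIToolCall
        id: str
        name: str

    @dataclass
    class ToolCallBook:
        pending: List[ToolCall] = field(default_factory=list)     # ~ oai_tool_calls
        by_id: Dict[str, ToolCall] = field(default_factory=dict)  # ~ oai_tool_id2call

        def on_llm_tool_calls(self, calls: List[ToolCall]) -> None:
            self.pending = list(calls)                   # all calls now await results
            self.by_id.update({c.id: c for c in calls})  # index of ALL calls ever made

        def on_tool_result(self, tool_call_id: str) -> None:
            # a Role.TOOL result for this id means the call is no longer pending
            self.pending = [c for c in self.pending if c.id != tool_call_id]

    book = ToolCallBook()
    book.on_llm_tool_calls([ToolCall("call_1", "search"), ToolCall("call_2", "calc")])
    book.on_tool_result("call_1")
    assert [c.id for c in book.pending] == ["call_2"]
    assert set(book.by_id) == {"call_1", "call_2"}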
@@ -343,30 +350,129 @@ class Agent(ABC):
             )
             return results
         if not settings.quiet:
+            results_str = (
+                results if isinstance(results, str) else json.dumps(results, indent=2)
+            )
             console.print(f"[red]{self.indent}", end="")
-            print(f"[red]Agent: {escape(
-            maybe_json = len(extract_top_level_json(
+            print(f"[red]Agent: {escape(results_str)}")
+            maybe_json = len(extract_top_level_json(results_str)) > 0
             self.callbacks.show_agent_response(
-                content=
+                content=results_str,
                 language="json" if maybe_json else "text",
             )
         sender_name = self.config.name
         if isinstance(msg, ChatDocument) and msg.function_call is not None:
             # if result was from handling an LLM `function_call`,
-            # set sender_name to
+            # set sender_name to name of the function_call
             sender_name = msg.function_call.name

+        results_str, id2result, oai_tool_id = self._process_tool_results(
+            results if isinstance(results, str) else "",
+            id2result=None if isinstance(results, str) else results,
+            tool_calls=(msg.oai_tool_calls if isinstance(msg, ChatDocument) else None),
+        )
         return ChatDocument(
-            content=
+            content=results_str,
+            oai_tool_id2result=id2result,
             metadata=ChatDocMetaData(
                 source=Entity.AGENT,
                 sender=Entity.AGENT,
                 sender_name=sender_name,
+                oai_tool_id=oai_tool_id,
                 # preserve trail of tool_ids for OpenAI Assistant fn-calls
                 tool_ids=[] if isinstance(msg, str) else msg.metadata.tool_ids,
             ),
         )

+    def _process_tool_results(
+        self,
+        results: str,
+        id2result: Dict[str, str] | None,
+        tool_calls: List[OpenAIToolCall] | None = None,
+    ) -> Tuple[str, Dict[str, str] | None, str | None]:
+        """
+        Process results from a response, based on whether
+        they are results of OpenAI tool-calls from THIS agent.
+
+        Return:
+            - str: The response string
+            - Dict[str,str]|None: A dict of OpenAI tool id -> result, if there are
+                multiple tool results.
+            - str|None: tool_id if there was a single tool result
+
+        """
+        id2result_ = copy.deepcopy(id2result) if id2result is not None else None
+        results_str = ""
+        oai_tool_id = None
+
+        if results != "":
+            # in this case ignore id2result
+            assert (
+                id2result is None
+            ), "id2result should be None when results is not empty!"
+            results_str = results
+            if len(self.oai_tool_calls) > 0:
+                # There may be multiple tool-calls, but we only have one result,
+                # so we return it as a map of the first tool-call id to the result.
+                assert (
+                    len(self.oai_tool_calls) == 1
+                ), "There are multiple pending tool-calls, but only one result!"
+                # We record the tool_id of the tool-call that
+                # the result is a response to, so that ChatDocument.to_LLMMessage
+                # can properly set the `tool_call_id` field of the LLMMessage.
+                oai_tool_id = self.oai_tool_calls[0].id
+        elif id2result is not None and id2result_ is not None:  # appease mypy
+            assert (
+                tool_calls is not None
+            ), "tool_calls cannot be None when id2result is not None!"
+            # This must be an OpenAI tool id -> result map;
+            # However some ids may not correspond to the tool-calls in the list of
+            # pending tool-calls (self.oai_tool_calls). Such results are concatenated
+            # into a simple string, to store in the ChatDocument.content,
+            # and the rest
+            # (i.e. those that DO correspond to tools in self.oai_tool_calls)
+            # are stored as a dict in ChatDocument.oai_tool_id2result.

+            # OAI tools from THIS agent, awaiting response
+            pending_tool_ids = [tc.id for tc in self.oai_tool_calls]
+            # tool_calls that the results are a response to
+            # (but these may have been sent from another agent, hence may not be in
+            # self.oai_tool_calls)
+            parent_tool_id2name = {
+                tc.id: tc.function.name
+                for tc in tool_calls or []
+                if tc.function is not None
+            }

+            # (id, result) for result NOT corresponding to self.oai_tool_calls,
+            # i.e. these are results of EXTERNAL tool-calls from another agent.
+            external_tool_id_results = []

+            for tc_id, result in id2result.items():
+                if tc_id not in pending_tool_ids:
+                    external_tool_id_results.append((tc_id, result))
+                    id2result_.pop(tc_id)
+            if len(external_tool_id_results) == 0:
+                results_str = ""
+            elif len(external_tool_id_results) == 1:
+                results_str = external_tool_id_results[0][1]
+            else:
+                results_str = "\n\n".join(
+                    [
+                        f"Result from tool/function {parent_tool_id2name[id]}: {result}"
+                        for id, result in external_tool_id_results
+                    ]
+                )

+            if len(id2result_) == 0:
+                id2result_ = None
+            elif len(id2result_) == 1 and len(external_tool_id_results) == 0:
+                results_str = list(id2result_.values())[0]
+                oai_tool_id = list(id2result_.keys())[0]
+                id2result_ = None

+        return results_str, id2result_, oai_tool_id

     def _response_template(self, e: Entity, content: str | None = None) -> ChatDocument:
         """Template for response from entity `e`."""
         return ChatDocument(
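To make the docstring's three-way return concrete, here is a standalone simplification (hypothetical; it collapses the external-tool-call branch that the real method also handles, and the ids are invented):

    from typing import Dict, List, Optional, Tuple

    def process_tool_results_sketch(
        results: str,
        id2result: Optional[Dict[str, str]],
        pending_ids: List[str],
    ) -> Tuple[str, Optional[Dict[str, str]], Optional[str]]:
        if results != "":
            # a single plain-string result; attach the lone pending call's id, if any
            assert id2result is None
            return results, None, (pending_ids[0] if len(pending_ids) == 1 else None)
        assert id2result is not None
        if len(id2result) == 1:
            # a single tool result -> collapse to (str, None, tool_id)
            ((tc_id, res),) = id2result.items()
            return res, None, tc_id
        # multiple results for THIS agent's calls -> keep the id -> result map
        return "", dict(id2result), None

    print(process_tool_results_sketch("42", None, ["call_1"]))
    # -> ('42', None, 'call_1')
    print(process_tool_results_sketch("", {"call_1": "42", "call_2": "ok"}, ["call_1", "call_2"]))
    # -> ('', {'call_1': '42', 'call_2': 'ok'}, None)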
@@ -627,16 +733,25 @@ class Agent(ABC):
             # (either via OpenAI Fn-call or Langroid-native ToolMessage)
             return msg.tool_messages
         assert isinstance(msg, ChatDocument)
-
-
+        if (
+            msg.content != ""
+            and msg.oai_tool_calls is None
+            and msg.function_call is None
+        ):
             tools = self.get_json_tool_messages(msg.content)
             msg.tool_messages = tools
             return tools

-        # otherwise, we look for
-
-        tools = [fun_call_cls] if fun_call_cls is not None else []
+        # otherwise, we look for `tool_calls` (possibly multiple)
+        tools = self.get_oai_tool_calls_classes(msg)
         msg.tool_messages = tools
+
+        if len(tools) == 0:
+            # otherwise, we look for a `function_call`
+            fun_call_cls = self.get_function_call_class(msg)
+            tools = [fun_call_cls] if fun_call_cls is not None else []
+            msg.tool_messages = tools
         return tools

     def get_json_tool_messages(self, input_str: str) -> List[ToolMessage]:
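A condensed restatement of the new lookup order, with a stand-in `Msg` class (the real code operates on `ChatDocument` via `get_json_tool_messages`, `get_oai_tool_calls_classes`, and `get_function_call_class`; the string tags here are invented):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class Msg:  # stand-in for ChatDocument
        content: str = ""
        oai_tool_calls: Optional[List[str]] = None
        function_call: Optional[str] = None

    def get_tools_sketch(msg: Msg) -> List[str]:
        # 1) plain content with no native calls: parse langroid-native JSON tools
        if msg.content != "" and msg.oai_tool_calls is None and msg.function_call is None:
            return [f"json-tool:{msg.content}"]
        # 2) otherwise prefer OpenAI `tool_calls` (possibly several)...
        tools = [f"oai-tool:{tc}" for tc in (msg.oai_tool_calls or [])]
        if len(tools) == 0 and msg.function_call is not None:
            # 3) ...falling back to the single legacy `function_call`
            tools = [f"fn-call:{msg.function_call}"]
        return tools

    assert get_tools_sketch(Msg(oai_tool_calls=["search"])) == ["oai-tool:search"]
    assert get_tools_sketch(Msg(function_call="search")) == ["fn-call:search"]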
@@ -656,6 +771,10 @@ class Agent(ABC):
         return [r for r in results if r is not None]

     def get_function_call_class(self, msg: ChatDocument) -> Optional[ToolMessage]:
+        """
+        From ChatDocument (constructed from an LLM Response), get the `ToolMessage`
+        corresponding to the `function_call` if it exists.
+        """
         if msg.function_call is None:
             return None
         tool_name = msg.function_call.name
@@ -677,6 +796,39 @@ class Agent(ABC):
         tool = tool_class.parse_obj(tool_msg)
         return tool

+    def get_oai_tool_calls_classes(self, msg: ChatDocument) -> List[ToolMessage]:
+        """
+        From ChatDocument (constructed from an LLM Response), get
+        a list of ToolMessages corresponding to the `tool_calls`, if any.
+        """
+
+        if msg.oai_tool_calls is None:
+            return []
+        tools = []
+        for tc in msg.oai_tool_calls:
+            if tc.function is None:
+                continue
+            tool_name = tc.function.name
+            tool_msg = tc.function.arguments or {}
+            if tool_name not in self.llm_tools_handled:
+                logger.warning(
+                    f"""
+                    The tool_call '{tool_name}' is not handled
+                    by the agent named '{self.config.name}'!
+                    If you intended this agent to handle this function_call,
+                    either the fn-call name is incorrectly generated by the LLM,
+                    (in which case you may need to adjust your LLM instructions),
+                    or you need to enable this agent to handle this fn-call.
+                    """
+                )
+                continue
+            tool_class = self.llm_tools_map[tool_name]
+            tool_msg.update(dict(request=tool_name))
+            tool = tool_class.parse_obj(tool_msg)
+            tool.id = tc.id or ""
+            tools.append(tool)
+        return tools
+
     def tool_validation_error(self, ve: ValidationError) -> str:
         """
         Handle a validation error raised when parsing a tool message,
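The parse step at the heart of this loop is ordinary pydantic validation: the tool-call's JSON arguments plus a `request` field are validated into a `ToolMessage` subclass. An illustration with an invented `PolygonArea` tool (langroid's `ToolMessage` subclasses are pydantic models, so `parse_obj` applies):

    from pydantic import BaseModel

    class PolygonArea(BaseModel):  # invented stand-in for a ToolMessage subclass
        request: str
        n_sides: int
        side_len: float

    # as in the loop above: the tool-call's JSON arguments, plus the `request` field
    tool_msg = {"n_sides": 5, "side_len": 2.0}     # ~ tc.function.arguments
    tool_msg.update(dict(request="polygon_area"))
    tool = PolygonArea.parse_obj(tool_msg)
    print(tool)  # request='polygon_area' n_sides=5 side_len=2.0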
@@ -699,7 +851,9 @@ class Agent(ABC):
         Please write your message again, correcting the errors.
         """

-    def handle_message(
+    def handle_message(
+        self, msg: str | ChatDocument
+    ) -> None | str | Dict[str, str] | ChatDocument:
         """
         Handle a "tool" message either a string containing one or more
         valid "tool" JSON substrings, or a
@@ -711,9 +865,16 @@ class Agent(ABC):
             msg (str | ChatDocument): The string or ChatDocument to handle

         Returns:
-
-
-
+            The result of the handler method can be:
+            - None if no tools successfully handled, or no tools present
+            - str if langroid-native JSON tools were handled, and results concatenated,
+               OR there's a SINGLE OpenAI tool-call.
+               (We do this so the common scenario of a single tool/fn-call
+                has a simple behavior).
+            - Dict[str, str] if multiple OpenAI tool-calls were handled
+                (dict is an id->result map)
+            - ChatDocument if a handler returned a ChatDocument, intended to be the
+                final response of the `agent_response` method.
         """
         try:
             tools = self.get_tool_messages(msg)
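A small dispatcher showing how a caller might branch on this widened return type (the import path matches this package's layout; the function itself is illustrative):

    from typing import Dict, Union

    from langroid.agent.chat_document import ChatDocument

    def describe_result(result: Union[None, str, Dict[str, str], ChatDocument]) -> str:
        if result is None:
            return "no tools present, or none successfully handled"
        if isinstance(result, str):
            return "native JSON tools handled, or a single OpenAI tool-call"
        if isinstance(result, dict):
            return "multiple OpenAI tool-calls handled: id -> result map"
        return "a handler returned a ChatDocument as the final response"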
@@ -729,14 +890,25 @@ class Agent(ABC):
             return None
         if len(tools) == 0:
             return self.handle_message_fallback(msg)
-
+        has_ids = all([t.id != "" for t in tools])
         results = [self.handle_tool_message(t) for t in tools]
-
-
-
+        tool_names = [t.default_value("request") for t in tools]
+        if has_ids:
+            id2result = {
+                t.id: r
+                for t, r in zip(tools, results)
+                if r is not None and isinstance(r, str)
+            }
+
+        name_results_list = [
+            (name, r) for name, r in zip(tool_names, results) if r is not None
+        ]
+        if len(name_results_list) == 0:
             return None  # self.handle_message_fallback(msg)
         # there was a non-None result
-        chat_doc_results = [
+        chat_doc_results = [
+            r for _, r in name_results_list if isinstance(r, ChatDocument)
+        ]
         if len(chat_doc_results) > 1:
             logger.warning(
                 """There were multiple ChatDocument results from tools,
@@ -747,8 +919,21 @@ class Agent(ABC):
         if len(chat_doc_results) > 0:
             return chat_doc_results[0]

-
-
+        if has_ids and len(id2result) > 1:
+            # if there are multiple OpenAI Tool results, return them as a dict
+            return id2result
+
+        if len(name_results_list) == 1 and isinstance(name_results_list[0][1], str):
+            # single str result -- return it as is
+            return name_results_list[0][1]
+
+        # multi-results: prepend the tool name to each result
+        str_results = [
+            f"Result from {name}: {r}"
+            for name, r in name_results_list
+            if isinstance(r, str)
+        ]
+        final = "\n\n".join(str_results)
         return final

     def handle_message_fallback(
langroid/agent/callbacks/chainlit.py
CHANGED
@@ -58,7 +58,7 @@ async def setup_llm() -> None:
         timeout = llm_settings.get("timeout", 90)
         logger.info(f"Using model: {model}")
         llm_config = lm.OpenAIGPTConfig(
-            chat_model=model or lm.OpenAIChatModel.
+            chat_model=model or lm.OpenAIChatModel.GPT4o,
             # or, other possibilities for example:
             # "litellm/ollama_chat/mistral"
             # "litellm/ollama_chat/mistral:7b-instruct-v0.2-q8_0"
langroid/agent/chat_agent.py
CHANGED
@@ -16,8 +16,10 @@ from langroid.language_models.base import (
     LLMFunctionSpec,
     LLMMessage,
     LLMResponse,
+    OpenAIToolSpec,
     Role,
     StreamingIfAllowed,
+    ToolChoiceTypes,
 )
 from langroid.language_models.openai_gpt import OpenAIGPT
 from langroid.utils.configuration import settings
@@ -39,14 +41,19 @@ class ChatAgentConfig(AgentConfig):
         user_message: user message to include in message sequence.
             Used only if `task` is not specified in the constructor.
         use_tools: whether to use our own ToolMessages mechanism
-        use_functions_api: whether to use functions native to the LLM API
-            (e.g. OpenAI's `function_call` mechanism)
+        use_functions_api: whether to use functions/tools native to the LLM API
+            (e.g. OpenAI's `function_call` or `tool_call` mechanism)
+        use_tools_api: When `use_functions_api` is True, if this is also True,
+            the OpenAI tool-call API is used, rather than the older/deprecated
+            function-call API. However the tool-call API has some tricky aspects,
+            hence we set this to False by default.
     """

     system_message: str = "You are a helpful assistant."
     user_message: Optional[str] = None
     use_tools: bool = False
     use_functions_api: bool = True
+    use_tools_api: bool = False

     def _set_fn_or_tools(self, fn_available: bool) -> None:
         """
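Opting into the newer API is then a two-flag change on the config (a minimal sketch; flag semantics per the docstring above):

    from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig

    config = ChatAgentConfig(
        use_functions_api=True,  # use the LLM API's native fn/tool mechanism...
        use_tools_api=True,      # ...specifically tool-calls, not legacy function_call
    )
    agent = ChatAgent(config)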
@@ -205,6 +212,23 @@ class ChatAgent(Agent):
             msgs.append(LLMMessage(role=Role.USER, content=self.user_message))
         return msgs

+    def _drop_msg_update_tool_calls(self, msg: LLMMessage) -> None:
+        id2idx = {t.id: i for i, t in enumerate(self.oai_tool_calls)}
+        if msg.role == Role.TOOL:
+            # dropping tool result, so ADD the corresponding tool-call back
+            # to the list of pending calls!
+            id = msg.tool_call_id
+            if id in self.oai_tool_id2call:
+                self.oai_tool_calls.append(self.oai_tool_id2call[id])
+        elif msg.tool_calls is not None:
+            # dropping a msg with tool-calls, so DROP these from pending list
+            # as well as from id -> call map
+            for tool_call in msg.tool_calls:
+                if tool_call.id in id2idx:
+                    self.oai_tool_calls.pop(id2idx[tool_call.id])
+                if tool_call.id in self.oai_tool_id2call:
+                    del self.oai_tool_id2call[tool_call.id]
+
     def clear_history(self, start: int = -2) -> None:
         """
         Clear the message history, starting at the index `start`
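The symmetry in `_drop_msg_update_tool_calls` is easiest to see in isolation: dropping a tool RESULT re-opens its call, while dropping the assistant message that MADE the calls retires them. A self-contained sketch (MiniMsg, the role strings, and the ids are invented):

    from dataclasses import dataclass, field
    from typing import Dict, List, Optional

    @dataclass
    class MiniMsg:  # stand-in for LLMMessage
        role: str                                                # "assistant" | "tool"
        tool_call_ids: List[str] = field(default_factory=list)  # calls MADE by this msg
        tool_call_id: Optional[str] = None                      # set when role == "tool"

    def drop_msg_update(pending: List[str], index: Dict[str, str], msg: MiniMsg) -> None:
        if msg.role == "tool":
            # dropping a tool result: its call awaits a response again
            if msg.tool_call_id in index:
                pending.append(msg.tool_call_id)
        else:
            # dropping the msg that made the calls: forget them entirely
            for cid in msg.tool_call_ids:
                if cid in pending:
                    pending.remove(cid)
                index.pop(cid, None)

    pending, index = [], {"call_1": "search"}
    drop_msg_update(pending, index, MiniMsg(role="tool", tool_call_id="call_1"))
    assert pending == ["call_1"]
    drop_msg_update(pending, index, MiniMsg(role="assistant", tool_call_ids=["call_1"]))
    assert pending == [] and index == {}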
@@ -218,7 +242,10 @@ class ChatAgent(Agent):
         n = len(self.message_history)
         start = max(0, n + start)
         dropped = self.message_history[start:]
-
+        # consider the dropped msgs in REVERSE order, so we are
+        # carefully updating self.oai_tool_calls
+        for msg in reversed(dropped):
+            self._drop_msg_update_tool_calls(msg)
             # clear out the chat document from the ObjectRegistry
             ChatDocument.delete_id(msg.chat_document_id)
         self.message_history = self.message_history[:start]
@@ -519,9 +546,14 @@ class ChatAgent(Agent):
         hist, output_len = self._prep_llm_messages(message)
         if len(hist) == 0:
             return None
+        tool_choice = (
+            "auto"
+            if isinstance(message, str)
+            else (message.oai_tool_choice if message is not None else "auto")
+        )
         with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-            response = self.llm_response_messages(hist, output_len)
-            self.message_history.
+            response = self.llm_response_messages(hist, output_len, tool_choice)
+            self.message_history.extend(ChatDocument.to_LLMMessage(response))
         response.metadata.msg_idx = len(self.message_history) - 1
         response.metadata.agent_id = self.id
         # Preserve trail of tool_ids for OpenAI Assistant fn-calls
@@ -543,9 +575,16 @@ class ChatAgent(Agent):
         hist, output_len = self._prep_llm_messages(message)
         if len(hist) == 0:
             return None
+        tool_choice = (
+            "auto"
+            if isinstance(message, str)
+            else (message.oai_tool_choice if message is not None else "auto")
+        )
         with StreamingIfAllowed(self.llm, self.llm.get_stream()):
-            response = await self.llm_response_messages_async(
-
+            response = await self.llm_response_messages_async(
+                hist, output_len, tool_choice
+            )
+            self.message_history.extend(ChatDocument.to_LLMMessage(response))
         response.metadata.msg_idx = len(self.message_history) - 1
         response.metadata.agent_id = self.id
         # Preserve trail of tool_ids for OpenAI Assistant fn-calls
@@ -622,8 +661,14 @@ class ChatAgent(Agent):
         ):
             # either the message is a str, or it is a fresh ChatDocument
             # different from the last message in the history
-
-
+            llm_msgs = ChatDocument.to_LLMMessage(message, self.oai_tool_calls)
+
+            # process tools if any
+            done_tools = [m.tool_call_id for m in llm_msgs if m.role == Role.TOOL]
+            self.oai_tool_calls = [
+                t for t in self.oai_tool_calls if t.id not in done_tools
+            ]
+            self.message_history.extend(llm_msgs)

         hist = self.message_history
         output_len = self.config.llm.max_output_tokens
@@ -707,18 +752,47 @@ class ChatAgent(Agent):

     def _function_args(
         self,
-    ) -> Tuple[
+    ) -> Tuple[
+        Optional[List[LLMFunctionSpec]],
+        str | Dict[str, str],
+        Optional[List[OpenAIToolSpec]],
+        Optional[Dict[str, Dict[str, str] | str]],
+    ]:
+        """Get function/tool spec arguments for OpenAI-compatible LLM API call"""
         functions: Optional[List[LLMFunctionSpec]] = None
         fun_call: str | Dict[str, str] = "none"
+        tools: Optional[List[OpenAIToolSpec]] = None
+        force_tool: Optional[Dict[str, Dict[str, str] | str]] = None
         if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
-
-
-
-
-
+            if not self.config.use_tools_api:
+                functions = [
+                    self.llm_functions_map[f] for f in self.llm_functions_usable
+                ]
+                fun_call = (
+                    "auto"
+                    if self.llm_function_force is None
+                    else self.llm_function_force
+                )
+            else:
+                tools = [
+                    OpenAIToolSpec(type="function", function=self.llm_functions_map[f])
+                    for f in self.llm_functions_usable
+                ]
+                force_tool = (
+                    None
+                    if self.llm_function_force is None
+                    else {
+                        "type": "function",
+                        "function": {"name": self.llm_function_force["name"]},
+                    }
+                )
+        return functions, fun_call, tools, force_tool

     def llm_response_messages(
-        self,
+        self,
+        messages: List[LLMMessage],
+        output_len: Optional[int] = None,
+        tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
     ) -> ChatDocument:
         """
         Respond to a series of messages, e.g. with OpenAI ChatCompletion
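For reference, the two OpenAI chat-completions request shapes this method now chooses between, written out as plain dicts (field layout per the public OpenAI API; the polygon_area schema is invented):

    params = {"type": "object", "properties": {"n_sides": {"type": "integer"}}}

    # older, deprecated function-call API (use_tools_api=False)
    legacy = {
        "functions": [{"name": "polygon_area", "parameters": params}],
        "function_call": "auto",  # or {"name": "polygon_area"} to force the call
    }

    # newer tool-call API (use_tools_api=True)
    tools_api = {
        "tools": [{"type": "function",
                   "function": {"name": "polygon_area", "parameters": params}}],
        "tool_choice": "auto",  # or {"type": "function",
                                #     "function": {"name": "polygon_area"}} to force
    }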
@@ -748,11 +822,13 @@ class ChatAgent(Agent):
             stack.enter_context(cm)
         if self.llm.get_stream() and not settings.quiet:
             console.print(f"[green]{self.indent}", end="")
-        functions, fun_call = self._function_args()
+        functions, fun_call, tools, force_tool = self._function_args()
         assert self.llm is not None
         response = self.llm.chat(
             messages,
             output_len,
+            tools=tools,
+            tool_choice=force_tool or tool_choice,
             functions=functions,
             function_call=fun_call,
         )
@@ -775,23 +851,24 @@ class ChatAgent(Agent):
             print_response_stats=self.config.show_stats and not settings.quiet,
         )
         chat_doc = ChatDocument.from_LLMResponse(response, displayed=True)
+        self.oai_tool_calls = response.oai_tool_calls or []
+        self.oai_tool_id2call.update(
+            {t.id: t for t in self.oai_tool_calls if t.id is not None}
+        )
         return chat_doc

     async def llm_response_messages_async(
-        self,
+        self,
+        messages: List[LLMMessage],
+        output_len: Optional[int] = None,
+        tool_choice: ToolChoiceTypes | Dict[str, str | Dict[str, str]] = "auto",
     ) -> ChatDocument:
         """
         Async version of `llm_response_messages`. See there for details.
         """
         assert self.config.llm is not None and self.llm is not None
         output_len = output_len or self.config.llm.max_output_tokens
-        functions
-        fun_call: str | Dict[str, str] = "none"
-        if self.config.use_functions_api and len(self.llm_functions_usable) > 0:
-            functions = [self.llm_functions_map[f] for f in self.llm_functions_usable]
-            fun_call = (
-                "auto" if self.llm_function_force is None else self.llm_function_force
-            )
+        functions, fun_call, tools, force_tool = self._function_args()
         assert self.llm is not None

         streamer = noop_fn
@@ -802,6 +879,8 @@ class ChatAgent(Agent):
         response = await self.llm.achat(
             messages,
             output_len,
+            tools=tools,
+            tool_choice=force_tool or tool_choice,
             functions=functions,
             function_call=fun_call,
         )
@@ -824,6 +903,10 @@ class ChatAgent(Agent):
             print_response_stats=self.config.show_stats and not settings.quiet,
         )
         chat_doc = ChatDocument.from_LLMResponse(response, displayed=True)
+        self.oai_tool_calls = response.oai_tool_calls or []
+        self.oai_tool_id2call.update(
+            {t.id: t for t in self.oai_tool_calls if t.id is not None}
+        )
         return chat_doc

     def _render_llm_response(
@@ -847,6 +930,7 @@ class ChatAgent(Agent):
             if isinstance(response, ChatDocument)
             else ChatDocument.from_LLMResponse(response, displayed=True)
         )
+        # TODO: prepend TOOL: or OAI-TOOL: if it's a tool-call
         print(cached + "[green]" + escape(str(response)))
         self.callbacks.show_llm_response(
             content=str(response),
@@ -923,8 +1007,14 @@ class ChatAgent(Agent):
         # If there is a response, then we will have two additional
         # messages in the message history, i.e. the user message and the
         # assistant response. We want to (carefully) remove these two messages.
-
-
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
+
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
+
         return response

     async def llm_response_forget_async(self, message: str) -> ChatDocument:
@@ -941,8 +1031,13 @@ class ChatAgent(Agent):
         # If there is a response, then we will have two additional
         # messages in the message history, i.e. the user message and the
         # assistant response. We want to (carefully) remove these two messages.
-
-
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
+
+        if len(self.message_history) > n_msgs:
+            msg = self.message_history.pop()
+            self._drop_msg_update_tool_calls(msg)
         return response

     def chat_num_tokens(self, messages: Optional[List[LLMMessage]] = None) -> int: